Skip to content

Commit

Permalink
OPIK-904: Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
thiagohora committed Jan 30, 2025
1 parent a1eecf2 commit 56099b2
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 114 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,8 @@ public Response getProjectStats(

log.info("Find projects stats by '{}' on workspaceId '{}'", criteria, workspaceId);
ProjectStatsSummary projectStatisticsSummary = projectService.getStats(page, size, criteria, sortingFields);
log.info("Found projects stats by '{}', count '{}' on workspaceId '{}'", criteria, projectStatisticsSummary.content().size(), workspaceId);
log.info("Found projects stats by '{}', count '{}' on workspaceId '{}'", criteria,
projectStatisticsSummary.content().size(), workspaceId);

return Response.ok().entity(projectStatisticsSummary).build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ public interface ProjectService {

void recordLastUpdatedTrace(String workspaceId, Collection<ProjectIdLastUpdated> lastUpdatedTraces);

ProjectStatsSummary getStats(int page, int size, @NonNull ProjectCriteria criteria, @NonNull List<SortingField> sortingFields);
ProjectStatsSummary getStats(int page, int size, @NonNull ProjectCriteria criteria,
@NonNull List<SortingField> sortingFields);
}

@Slf4j
Expand Down Expand Up @@ -487,7 +488,8 @@ public Project retrieveByName(@NonNull String projectName) {
.feedbackScores(StatsMapper.getStatsFeedbackScores(projectStats.get(project.id())))
.usage(StatsMapper.getStatsUsage(projectStats.get(project.id())))
.duration(StatsMapper.getStatsDuration(projectStats.get(project.id())))
.totalEstimatedCost(StatsMapper.getStatsTotalEstimatedCost(projectStats.get(project.id())))
.totalEstimatedCost(
StatsMapper.getStatsTotalEstimatedCost(projectStats.get(project.id())))
.build();
})
.orElseThrow(this::createNotFoundError);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,7 @@ AND notEquals(s.start_time, toDateTime64('1970-01-01 00:00:00.000', 9)),
FROM (
SELECT
*,
row_number() OVER (PARTITION BY id ORDER BY last_updated_at DESC) AS latest\s
row_number() OVER (PARTITION BY id ORDER BY last_updated_at DESC) AS latest
FROM spans
WHERE id IN (SELECT id FROM span_ids)
) AS s
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ class ProjectsResourceTest {
public static final String URL_TEMPLATE_TRACE = "%s/v1/private/traces";
public static final String[] IGNORED_FIELDS = {"createdBy", "lastUpdatedBy", "createdAt", "lastUpdatedAt",
"lastUpdatedTraceAt", "feedbackScores", "duration", "totalEstimatedCost", "usage"};
public static final String[] IGNORED_FIELD_MIN = {"createdBy", "lastUpdatedBy", "createdAt", "lastUpdatedAt",
"lastUpdatedTraceAt"};

private static final String API_KEY = UUID.randomUUID().toString();
private static final String USER = UUID.randomUUID().toString();
Expand Down Expand Up @@ -582,23 +584,26 @@ void getProjectById__whenProjectExists__thenReturnProject() {

var id = createProject(project, apiKey, workspaceName);

project = buildProjectStats(project.toBuilder().id(id).build(), apiKey, workspaceName);

try (var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI))
.path("retrieve")
.request()
.header(HttpHeaders.AUTHORIZATION, apiKey)
.header(WORKSPACE_HEADER, workspaceName)
.post(Entity.json(ProjectRetrieve.builder().name(project.name()).build()))) {

assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(200);
assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(HttpStatus.SC_OK);
assertThat(actualResponse.hasEntity()).isTrue();

var actualEntity = actualResponse.readEntity(Project.class);
assertThat(actualEntity)
.usingRecursiveComparison()
.ignoringFields(IGNORED_FIELDS)
.isEqualTo(project.toBuilder()
.id(id)
.build());
.ignoringFields(IGNORED_FIELD_MIN)
.ignoringCollectionOrder()
.withComparatorForType(StatsUtils::bigDecimalComparator, BigDecimal.class)
.withComparatorForFields(StatsUtils::closeToEpsilonComparator, "totalEstimatedCost")
.isEqualTo(project);
}
}

Expand Down Expand Up @@ -1153,7 +1158,8 @@ void getProjects__whenProjectsHasTracesSpansFeedbackScoresAndUsage__thenReturnPr

Comparator<Project> comparator = Comparator.comparing(Project::id).reversed();

List<ProjectStatsSummaryItem> expectedProjectStats = getProjectStatsSummaryItems(apiKey, workspaceName, comparator);
List<ProjectStatsSummaryItem> expectedProjectStats = getProjectStatsSummaryItems(apiKey, workspaceName,
comparator);

var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI))
.path("/stats")
Expand All @@ -1162,13 +1168,14 @@ void getProjects__whenProjectsHasTracesSpansFeedbackScoresAndUsage__thenReturnPr
.header(WORKSPACE_HEADER, workspaceName)
.get();

var actualEntity = actualResponse.readEntity(Project.ProjectPage.class);
var actualEntity = actualResponse.readEntity(ProjectStatsSummary.class);
assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(org.apache.http.HttpStatus.SC_OK);

assertThat(expectedProjectStats).hasSameSizeAs(actualEntity.content());

assertThat(actualEntity.content())
.usingRecursiveComparison()
.ignoringCollectionOrder()
.withComparatorForType(StatsUtils::bigDecimalComparator, BigDecimal.class)
.withComparatorForFields(StatsUtils::closeToEpsilonComparator, "totalEstimatedCost")
.isEqualTo(expectedProjectStats);
Expand All @@ -1185,7 +1192,8 @@ void getProjects__whenProjectsHasTracesSpansFeedbackScoresAndUsageSortedLastTrac

Comparator<Project> comparator = Comparator.comparing(Project::lastUpdatedTraceAt).reversed();

List<ProjectStatsSummaryItem> expectedProjectStats = getProjectStatsSummaryItems(apiKey, workspaceName, comparator);
List<ProjectStatsSummaryItem> expectedProjectStats = getProjectStatsSummaryItems(apiKey, workspaceName,
comparator);

var sorting = List.of(SortingField.builder()
.field(SortableFields.LAST_UPDATED_TRACE_AT)
Expand All @@ -1209,12 +1217,14 @@ void getProjects__whenProjectsHasTracesSpansFeedbackScoresAndUsageSortedLastTrac

assertThat(actualEntity.content())
.usingRecursiveComparison()
.ignoringCollectionOrder()
.withComparatorForType(StatsUtils::bigDecimalComparator, BigDecimal.class)
.withComparatorForFields(StatsUtils::closeToEpsilonComparator, "totalEstimatedCost")
.isEqualTo(expectedProjectStats);
}

private List<ProjectStatsSummaryItem> getProjectStatsSummaryItems(String apiKey, String workspaceName, Comparator<Project> comparing) {
private List<ProjectStatsSummaryItem> getProjectStatsSummaryItems(String apiKey, String workspaceName,
Comparator<Project> comparing) {
var projects = PodamFactoryUtils.manufacturePojoList(factory, Project.class)
.parallelStream()
.map(project -> project.toBuilder()
Expand Down Expand Up @@ -1266,6 +1276,7 @@ void getProjects__whenProjectsHasNoTracesSpansFeedbackScoresAndUsage__thenReturn
.totalEstimatedCost(null)
.usage(null)
.feedbackScores(null)
.projectId(project.id())
.build())
.sorted(Comparator.comparing(ProjectStatsSummaryItem::projectId).reversed())
.toList();
Expand All @@ -1290,107 +1301,6 @@ void getProjects__whenProjectsHasNoTracesSpansFeedbackScoresAndUsage__thenReturn
.isEqualTo(expectedProjectStats);
}

/**
 * Populates the given project with generated traces, spans and feedback scores through the
 * resource clients, and returns a copy of the project carrying the statistics derived from
 * that generated data.
 *
 * @param project       project whose name and id are stamped on the generated data
 * @param apiKey        API key forwarded to the resource clients
 * @param workspaceName workspace name forwarded to the resource clients
 * @return the project rebuilt with duration percentiles, average cost, average usage,
 *         feedback score averages and the latest trace update timestamp
 */
private Project buildProjectStats(Project project, String apiKey, String workspaceName) {
    List<Trace> createdTraces = PodamFactoryUtils.manufacturePojoList(factory, Trace.class).stream()
            .map(generated -> {
                // End time is the start time plus a random millisecond offset from the 1-1000 range helper.
                Instant begin = Instant.now();
                Instant finish = begin.plusMillis(PodamUtils.getIntegerInRange(1, 1000));
                return generated.toBuilder()
                        .projectName(project.name())
                        .startTime(begin)
                        .endTime(finish)
                        .duration(DurationUtils.getDurationInMillisWithSubMilliPrecision(begin, finish))
                        .build();
            })
            .toList();

    traceResourceClient.batchCreateTraces(createdTraces, apiKey, workspaceName);

    // One set of generated scores is reused as a template for every trace.
    List<FeedbackScoreBatchItem> scoreTemplates = PodamFactoryUtils.manufacturePojoList(factory,
            FeedbackScoreBatchItem.class);

    List<Trace> tracesWithStats = createdTraces.stream().map(trace -> {
        // NOTE(review): span cost is explicitly nulled; presumably the expected cost is meant to
        // come from StatsUtils.aggregateSpansCost below - confirm.
        List<Span> traceSpans = PodamFactoryUtils.manufacturePojoList(factory, Span.class).stream()
                .map(span -> span.toBuilder()
                        .usage(spanResourceClient.getTokenUsage())
                        .model(spanResourceClient.randomModel().toString())
                        .traceId(trace.id())
                        .projectName(trace.projectName())
                        .totalEstimatedCost(null)
                        .build())
                .toList();
        spanResourceClient.batchCreateSpans(traceSpans, apiKey, workspaceName);

        List<FeedbackScoreBatchItem> traceScores = scoreTemplates.stream()
                .map(template -> template.toBuilder()
                        .projectId(project.id())
                        .projectName(project.name())
                        .id(trace.id())
                        .build())
                .toList();
        traceResourceClient.feedbackScores(traceScores, apiKey, workspaceName);

        // Only name and value are carried over into the expected trace scores.
        List<FeedbackScore> expectedScores = traceScores.stream()
                .map(score -> FeedbackScore.builder()
                        .value(score.value())
                        .name(score.name())
                        .build())
                .toList();

        return trace.toBuilder()
                .feedbackScores(expectedScores)
                .usage(StatsUtils.aggregateSpansUsage(traceSpans))
                .totalEstimatedCost(StatsUtils.aggregateSpansCost(traceSpans))
                .build();
    }).toList();

    // Quantiles are requested in the order p50, p90, p99 and read back by index.
    List<BigDecimal> percentiles = StatsUtils.calculateQuantiles(
            tracesWithStats.stream().map(Trace::duration).toList(),
            List.of(0.5, 0.90, 0.99));

    return project.toBuilder()
            .duration(new ProjectStats.PercentageValues(percentiles.get(0), percentiles.get(1),
                    percentiles.get(2)))
            .totalEstimatedCost(getTotalEstimatedCost(tracesWithStats))
            .usage(tracesWithStats.stream()
                    .map(Trace::usage)
                    .flatMap(usage -> usage.entrySet().stream())
                    .collect(groupingBy(Map.Entry::getKey, averagingDouble(Map.Entry::getValue))))
            .feedbackScores(getScoreAverages(tracesWithStats))
            .lastUpdatedTraceAt(
                    tracesWithStats.stream().map(Trace::lastUpdatedAt).max(Instant::compareTo).orElse(null))
            .build();
}

/**
 * Averages the feedback score values of all given traces, grouped by score name.
 *
 * @param traces traces whose feedback scores are aggregated
 * @return one {@code FeedbackScoreAverage} per distinct score name
 */
private List<FeedbackScoreAverage> getScoreAverages(List<Trace> traces) {
    // Step 1: flatten every trace's scores and average the values per score name.
    var averagesByName = traces.stream()
            .flatMap(trace -> trace.feedbackScores().stream())
            .collect(groupingBy(FeedbackScore::name,
                    BigDecimalCollectors.averagingBigDecimal(FeedbackScore::value)));

    // Step 2: turn each (name, average) pair into the DTO.
    return averagesByName.entrySet()
            .stream()
            .map(average -> FeedbackScoreAverage.builder()
                    .name(average.getKey())
                    .value(average.getValue())
                    .build())
            .toList();
}

/**
 * Computes the average estimated cost over the given traces: the sum of all non-null costs
 * divided by the number of traces whose cost is strictly positive.
 *
 * <p>Null costs are skipped in both the sum and the count, so a trace without a cost no longer
 * causes a {@link NullPointerException}. When no trace has a positive cost the method returns
 * {@code 0.0} instead of failing with an {@link ArithmeticException} on division by zero.
 *
 * @param traces traces whose {@code totalEstimatedCost} values are aggregated
 * @return the average cost rounded half-up to {@code ValidationUtils.SCALE} decimals,
 *         or {@code 0.0} when there is no positive cost to average over
 */
private double getTotalEstimatedCost(List<Trace> traces) {
    // Filter nulls once so the sum and the count see the same set of values.
    List<BigDecimal> costs = traces.stream()
            .map(Trace::totalEstimatedCost)
            .filter(Objects::nonNull)
            .toList();

    long count = costs.stream()
            .filter(cost -> cost.compareTo(BigDecimal.ZERO) > 0)
            .count();

    if (count == 0) {
        // Nothing to average over; dividing would throw ArithmeticException.
        return 0.0;
    }

    return costs.stream()
            .reduce(BigDecimal.ZERO, BigDecimal::add)
            .divide(BigDecimal.valueOf(count), ValidationUtils.SCALE, RoundingMode.HALF_UP)
            .doubleValue();
}

@Test
@DisplayName("when projects is with traces created in batch, then return project with last updated trace at")
void getProjects__whenProjectsHasTracesBatch__thenReturnProjectWithLastUpdatedTraceAt() {
Expand Down Expand Up @@ -1507,6 +1417,107 @@ private void assertAllProjectsHavePersistedLastTraceAt(String workspaceId, List<
}
}

/**
 * Populates the given project with generated traces, spans and feedback scores through the
 * resource clients, and returns a copy of the project carrying the statistics derived from
 * that generated data.
 *
 * @param project       project whose name and id are stamped on the generated data
 * @param apiKey        API key forwarded to the resource clients
 * @param workspaceName workspace name forwarded to the resource clients
 * @return the project rebuilt with duration percentiles, average cost, average usage,
 *         feedback score averages and the latest trace update timestamp
 */
private Project buildProjectStats(Project project, String apiKey, String workspaceName) {
    List<Trace> createdTraces = PodamFactoryUtils.manufacturePojoList(factory, Trace.class).stream()
            .map(generated -> {
                // End time is the start time plus a random millisecond offset from the 1-1000 range helper.
                Instant begin = Instant.now();
                Instant finish = begin.plusMillis(PodamUtils.getIntegerInRange(1, 1000));
                return generated.toBuilder()
                        .projectName(project.name())
                        .startTime(begin)
                        .endTime(finish)
                        .duration(DurationUtils.getDurationInMillisWithSubMilliPrecision(begin, finish))
                        .build();
            })
            .toList();

    traceResourceClient.batchCreateTraces(createdTraces, apiKey, workspaceName);

    // One set of generated scores is reused as a template for every trace.
    List<FeedbackScoreBatchItem> scoreTemplates = PodamFactoryUtils.manufacturePojoList(factory,
            FeedbackScoreBatchItem.class);

    List<Trace> tracesWithStats = createdTraces.stream().map(trace -> {
        // NOTE(review): span cost is explicitly nulled; presumably the expected cost is meant to
        // come from StatsUtils.aggregateSpansCost below - confirm.
        List<Span> traceSpans = PodamFactoryUtils.manufacturePojoList(factory, Span.class).stream()
                .map(span -> span.toBuilder()
                        .usage(spanResourceClient.getTokenUsage())
                        .model(spanResourceClient.randomModel().toString())
                        .traceId(trace.id())
                        .projectName(trace.projectName())
                        .totalEstimatedCost(null)
                        .build())
                .toList();
        spanResourceClient.batchCreateSpans(traceSpans, apiKey, workspaceName);

        List<FeedbackScoreBatchItem> traceScores = scoreTemplates.stream()
                .map(template -> template.toBuilder()
                        .projectId(project.id())
                        .projectName(project.name())
                        .id(trace.id())
                        .build())
                .toList();
        traceResourceClient.feedbackScores(traceScores, apiKey, workspaceName);

        // Only name and value are carried over into the expected trace scores.
        List<FeedbackScore> expectedScores = traceScores.stream()
                .map(score -> FeedbackScore.builder()
                        .value(score.value())
                        .name(score.name())
                        .build())
                .toList();

        return trace.toBuilder()
                .feedbackScores(expectedScores)
                .usage(StatsUtils.aggregateSpansUsage(traceSpans))
                .totalEstimatedCost(StatsUtils.aggregateSpansCost(traceSpans))
                .build();
    }).toList();

    // Quantiles are requested in the order p50, p90, p99 and read back by index.
    List<BigDecimal> percentiles = StatsUtils.calculateQuantiles(
            tracesWithStats.stream().map(Trace::duration).toList(),
            List.of(0.5, 0.90, 0.99));

    return project.toBuilder()
            .duration(new ProjectStats.PercentageValues(percentiles.get(0), percentiles.get(1),
                    percentiles.get(2)))
            .totalEstimatedCost(getTotalEstimatedCost(tracesWithStats))
            .usage(tracesWithStats.stream()
                    .map(Trace::usage)
                    .flatMap(usage -> usage.entrySet().stream())
                    .collect(groupingBy(Map.Entry::getKey, averagingDouble(Map.Entry::getValue))))
            .feedbackScores(getScoreAverages(tracesWithStats))
            .lastUpdatedTraceAt(
                    tracesWithStats.stream().map(Trace::lastUpdatedAt).max(Instant::compareTo).orElse(null))
            .build();
}

/**
 * Averages the feedback score values of all given traces, grouped by score name.
 *
 * @param traces traces whose feedback scores are aggregated
 * @return one {@code FeedbackScoreAverage} per distinct score name
 */
private List<FeedbackScoreAverage> getScoreAverages(List<Trace> traces) {
    // Step 1: flatten every trace's scores and average the values per score name.
    var averagesByName = traces.stream()
            .flatMap(trace -> trace.feedbackScores().stream())
            .collect(groupingBy(FeedbackScore::name,
                    BigDecimalCollectors.averagingBigDecimal(FeedbackScore::value)));

    // Step 2: turn each (name, average) pair into the DTO.
    return averagesByName.entrySet()
            .stream()
            .map(average -> FeedbackScoreAverage.builder()
                    .name(average.getKey())
                    .value(average.getValue())
                    .build())
            .toList();
}

/**
 * Computes the average estimated cost over the given traces: the sum of all non-null costs
 * divided by the number of traces whose cost is strictly positive.
 *
 * <p>Null costs are skipped in both the sum and the count, so a trace without a cost no longer
 * causes a {@link NullPointerException}. When no trace has a positive cost the method returns
 * {@code 0.0} instead of failing with an {@link ArithmeticException} on division by zero.
 *
 * @param traces traces whose {@code totalEstimatedCost} values are aggregated
 * @return the average cost rounded half-up to {@code ValidationUtils.SCALE} decimals,
 *         or {@code 0.0} when there is no positive cost to average over
 */
private double getTotalEstimatedCost(List<Trace> traces) {
    // Filter nulls once so the sum and the count see the same set of values.
    List<BigDecimal> costs = traces.stream()
            .map(Trace::totalEstimatedCost)
            .filter(Objects::nonNull)
            .toList();

    long count = costs.stream()
            .filter(cost -> cost.compareTo(BigDecimal.ZERO) > 0)
            .count();

    if (count == 0) {
        // Nothing to average over; dividing would throw ArithmeticException.
        return 0.0;
    }

    return costs.stream()
            .reduce(BigDecimal.ZERO, BigDecimal::add)
            .divide(BigDecimal.valueOf(count), ValidationUtils.SCALE, RoundingMode.HALF_UP)
            .doubleValue();
}

@Nested
@DisplayName("Get: {id}")
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
Expand Down

0 comments on commit 56099b2

Please sign in to comment.