Skip to content

Commit

Permalink
HSEARCH-5133 Test all types with metric aggregations
Browse files Browse the repository at this point in the history
  • Loading branch information
fax4ever committed Aug 26, 2024
1 parent e28a110 commit 532dc54
Show file tree
Hide file tree
Showing 3 changed files with 281 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/*
* SPDX-License-Identifier: Apache-2.0
* Copyright Red Hat Inc. and Hibernate Authors
*/
package org.hibernate.search.integrationtest.backend.tck.search.aggregation;

import static org.assertj.core.api.Assertions.assertThat;

import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Function;

import org.hibernate.search.engine.backend.document.model.dsl.IndexSchemaElement;
import org.hibernate.search.engine.backend.types.Aggregable;
import org.hibernate.search.integrationtest.backend.tck.testsupport.model.singlefield.SingleFieldIndexBinding;
import org.hibernate.search.integrationtest.backend.tck.testsupport.operations.MetricAggregationsTestCase;
import org.hibernate.search.integrationtest.backend.tck.testsupport.types.FieldTypeDescriptor;
import org.hibernate.search.integrationtest.backend.tck.testsupport.types.StandardFieldTypeDescriptor;
import org.hibernate.search.integrationtest.backend.tck.testsupport.util.extension.SearchSetupHelper;
import org.hibernate.search.util.impl.integrationtest.mapper.stub.BulkIndexer;
import org.hibernate.search.util.impl.integrationtest.mapper.stub.SimpleMappedIndex;
import org.hibernate.search.util.impl.integrationtest.mapper.stub.StubMappingScope;

import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

public class MetricFieldAggregationsIT {

private static final Set<StandardFieldTypeDescriptor<?>> supportedFieldTypes = new LinkedHashSet<>();
private static final List<MetricAggregationsTestCase<?>> testCases = new ArrayList<>();
private static final List<Arguments> parameters = new ArrayList<>();

static {
for ( StandardFieldTypeDescriptor<?> typeDescriptor : FieldTypeDescriptor.getAllStandard() ) {
MetricAggregationsTestCase<?> scenario = new MetricAggregationsTestCase<>( typeDescriptor );
if ( !scenario.supported() ) {
continue;
}
testCases.add( scenario );
supportedFieldTypes.add( typeDescriptor );
parameters.add( Arguments.of( scenario ) );
}
}

public static List<? extends Arguments> params() {
return parameters;
}

@RegisterExtension
public static final SearchSetupHelper setupHelper = SearchSetupHelper.create();

private static final Function<IndexSchemaElement, SingleFieldIndexBinding> bindingFactory =
root -> SingleFieldIndexBinding.create( root, supportedFieldTypes, c -> c.aggregable( Aggregable.YES ) );
private static final SimpleMappedIndex<SingleFieldIndexBinding> mainIndex =
SimpleMappedIndex.of( bindingFactory ).name( "main" );

@BeforeAll
static void setup() {
int expectedDocuments = 0;

setupHelper.start().withIndexes( mainIndex ).setup();
BulkIndexer indexer = mainIndex.bulkIndexer();
for ( MetricAggregationsTestCase<?> scenario : testCases ) {
expectedDocuments += scenario.contribute( indexer, mainIndex.binding() );
}
indexer.join();

long createdDocuments = mainIndex.createScope().query().where( f -> f.matchAll() )
.totalHitCountThreshold( expectedDocuments )
.toQuery().fetch( 0 ).total().hitCountLowerBound();
assertThat( createdDocuments ).isEqualTo( expectedDocuments );
}

@ParameterizedTest(name = "{0}")
@MethodSource("params")
public void test(MetricAggregationsTestCase<?> testCase) {
StubMappingScope scope = mainIndex.createScope();
MetricAggregationsTestCase.Result<?> result = testCase.testMetricsAggregation( scope, mainIndex.binding() );
if ( result.expectedSum() != null ) {
assertThat( result.computedSum() ).isEqualTo( result.expectedSum() );
}
assertThat( result.computedMin() ).isEqualTo( result.expectedMin() );
assertThat( result.computedMax() ).isEqualTo( result.expectedMax() );
assertThat( result.computedCount() ).isEqualTo( result.expectedCount() );
assertThat( result.computedCountDistinct() ).isEqualTo( result.expectedCountDistinct() );
assertThat( result.computedAvg() ).isEqualTo( result.expectedAvg() );
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
/*
* SPDX-License-Identifier: Apache-2.0
* Copyright Red Hat Inc. and Hibernate Authors
*/
package org.hibernate.search.integrationtest.backend.tck.testsupport.operations;

import static org.hibernate.search.util.impl.integrationtest.mapper.stub.StubMapperUtils.documentProvider;

import java.util.List;
import java.util.Locale;
import java.util.StringJoiner;

import org.hibernate.search.engine.backend.common.DocumentReference;
import org.hibernate.search.engine.search.aggregation.AggregationKey;
import org.hibernate.search.engine.search.predicate.dsl.SearchPredicateFactory;
import org.hibernate.search.engine.search.query.SearchQuery;
import org.hibernate.search.engine.search.query.SearchResult;
import org.hibernate.search.integrationtest.backend.tck.testsupport.model.singlefield.SingleFieldIndexBinding;
import org.hibernate.search.integrationtest.backend.tck.testsupport.types.FieldTypeDescriptor;
import org.hibernate.search.integrationtest.backend.tck.testsupport.types.values.MetricAggregationsValues;
import org.hibernate.search.integrationtest.backend.tck.testsupport.util.IndexFieldLocation;
import org.hibernate.search.integrationtest.backend.tck.testsupport.util.IndexFieldValueCardinality;
import org.hibernate.search.integrationtest.backend.tck.testsupport.util.TestedFieldStructure;
import org.hibernate.search.util.impl.integrationtest.mapper.stub.BulkIndexer;
import org.hibernate.search.util.impl.integrationtest.mapper.stub.StubMappingScope;

/**
* Denotes a metric aggregations test case for a particular {@link FieldTypeDescriptor}.
*/
public class MetricAggregationsTestCase<F> {

private final FieldTypeDescriptor<F, ?> typeDescriptor;
private final boolean supported;
private final MetricAggregationsValues<F> metricAggregationsValues;

public MetricAggregationsTestCase(FieldTypeDescriptor<F, ?> typeDescriptor) {
this.typeDescriptor = typeDescriptor;
metricAggregationsValues = typeDescriptor.metricAggregationsValues();
this.supported = metricAggregationsValues != null;
}

public boolean supported() {
return supported;
}

public int contribute(BulkIndexer indexer, SingleFieldIndexBinding binding) {
int i = 0;
for ( F value : metricAggregationsValues.values() ) {
String uniqueName = typeDescriptor.getUniqueName();
String keyA = String.format( Locale.ROOT, "%03d_ROOT_%s", ++i, uniqueName );
String keyB = String.format( Locale.ROOT, "%03d_NEST_%s", i, uniqueName );
String keyC = String.format( Locale.ROOT, "%03d_FLAT_%s", i, uniqueName );
indexer.add( documentProvider( keyA, document -> binding.initSingleValued( typeDescriptor,
IndexFieldLocation.ROOT, document, value ) ) );
indexer.add( documentProvider( keyB, document -> binding.initSingleValued(
typeDescriptor, IndexFieldLocation.IN_NESTED, document, value ) ) );
indexer.add( documentProvider( keyC, document -> binding.initSingleValued(
typeDescriptor, IndexFieldLocation.IN_FLATTENED, document, value ) ) );
}
return metricAggregationsValues.values().size() * 3;
}

public Result<F> testMetricsAggregation(StubMappingScope scope, SingleFieldIndexBinding binding) {
InternalResult<F> result = new InternalResult<>();
String fieldPath = binding.getFieldPath( TestedFieldStructure.of(
IndexFieldLocation.ROOT, IndexFieldValueCardinality.SINGLE_VALUED ), typeDescriptor );

SearchQuery<DocumentReference> query = scope.query().where( SearchPredicateFactory::matchAll )
.aggregation( result.sumKey, f -> f.sum().field( fieldPath, typeDescriptor.getJavaType() ) )
.aggregation( result.minKey, f -> f.min().field( fieldPath, typeDescriptor.getJavaType() ) )
.aggregation( result.maxKey, f -> f.max().field( fieldPath, typeDescriptor.getJavaType() ) )
.aggregation( result.countKey, f -> f.count().field( fieldPath ) )
.aggregation( result.countDistinctKey, f -> f.countDistinct().field( fieldPath ) )
.aggregation( result.avgKey, f -> f.avg().field( fieldPath, typeDescriptor.getJavaType() ) )
.toQuery();
result.apply( query );
return new Result<>( typeDescriptor.getJavaType(), metricAggregationsValues, result );
}

@Override
public String toString() {
return "Case{" + typeDescriptor + '}';
}

public static class Result<F> {
private final Class<F> javaType;
private final MetricAggregationsValues<F> metricAggregationsValues;
private final InternalResult<F> result;

public Result(Class<F> javaType, MetricAggregationsValues<F> metricAggregationsValues,
InternalResult<F> result) {
this.javaType = javaType;
this.metricAggregationsValues = metricAggregationsValues;
this.result = result;
}

public List<F> values() {
return metricAggregationsValues.values();
}

public F expectedSum() {
return metricAggregationsValues.sum();
}

public F expectedMin() {
return metricAggregationsValues.min();
}

public F expectedMax() {
return metricAggregationsValues.max();
}

public Long expectedCount() {
return metricAggregationsValues.count();
}

public Long expectedCountDistinct() {
return metricAggregationsValues.countDistinct();
}

public F expectedAvg() {
return metricAggregationsValues.avg();
}

public F computedSum() {
return result.sum;
}

public F computedMax() {
return result.max;
}

public F computedMin() {
return result.min;
}

public Long computedCount() {
return result.count;
}

public Long computedCountDistinct() {
return result.countDistinct;
}

public F computedAvg() {
return result.avg;
}

@Override
public String toString() {
return new StringJoiner( ", ", Result.class.getSimpleName() + "[", "]" )
.add( "javaType=" + javaType )
.add( "values=" + values() )
.toString();
}
}

private static class InternalResult<F> {
AggregationKey<F> sumKey = AggregationKey.of( "sum" );
AggregationKey<F> minKey = AggregationKey.of( "min" );
AggregationKey<F> maxKey = AggregationKey.of( "max" );
AggregationKey<Long> countKey = AggregationKey.of( "count" );
AggregationKey<Long> countDistinctKey = AggregationKey.of( "countDistinct" );
AggregationKey<F> avgKey = AggregationKey.of( "avg" );

F sum;
F min;
F max;
Long count;
Long countDistinct;
F avg;

void apply(SearchQuery<DocumentReference> query) {
SearchResult<DocumentReference> result = query.fetch( 0 );
sum = result.aggregation( sumKey );
min = result.aggregation( minKey );
max = result.aggregation( maxKey );
count = result.aggregation( countKey );
countDistinct = result.aggregation( countDistinctKey );
avg = result.aggregation( avgKey );
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,12 @@ public static List<TestedFieldStructure> all() {
return ALL;
}

public static TestedFieldStructure of(IndexFieldLocation location, IndexFieldValueCardinality cardinality) {
return new TestedFieldStructure( location, cardinality );
}

private static final List<TestedFieldStructure> ALL;

static {
List<TestedFieldStructure> values = new ArrayList<>();
for ( IndexFieldLocation location : IndexFieldLocation.values() ) {
Expand Down

0 comments on commit 532dc54

Please sign in to comment.