Configurable sketch accuracy in merge rollup task (#14373)
* Configurable sketch accuracy in merge rollup task

* Run mvn spotless:apply
davecromberge authored Dec 10, 2024
1 parent 9f2a727 commit 442c0fc
Showing 19 changed files with 185 additions and 29 deletions.
@@ -103,6 +103,7 @@ public static abstract class MergeTask {
// Merge config
public static final String MERGE_TYPE_KEY = "mergeType";
public static final String AGGREGATION_TYPE_KEY_SUFFIX = ".aggregationType";
+ public static final String AGGREGATION_FUNCTION_PARAMETERS_PREFIX = "aggregationFunctionParameters.";
public static final String MODE = "mode";
public static final String PROCESS_FROM_WATERMARK_MODE = "processFromWatermark";
public static final String PROCESS_ALL_MODE = "processAll";
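The new constant establishes per-column task config keys of the form aggregationFunctionParameters.<metricColumn>.<parameterName>; the parsing that consumes them is added at the end of this diff. A minimal sketch of what such entries might look like, using an invented metric column name ("views") and assuming the theta sketch parameter key is "nominalEntries":

  // Hypothetical merge rollup task config entries (column name invented for illustration)
  Map<String, String> taskConfigs = new HashMap<>();
  taskConfigs.put("views.aggregationType", "distinctCountThetaSketch");
  taskConfigs.put("aggregationFunctionParameters.views.nominalEntries", "16384");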
@@ -18,6 +18,7 @@
*/
package org.apache.pinot.core.segment.processing.aggregator;

+ import java.util.Map;
import org.apache.datasketches.cpc.CpcSketch;
import org.apache.datasketches.cpc.CpcUnion;
import org.apache.pinot.core.common.ObjectSerDeUtils;
@@ -30,7 +31,7 @@ public DistinctCountCPCSketchAggregator() {
}

@Override
- public Object aggregate(Object value1, Object value2) {
+ public Object aggregate(Object value1, Object value2, Map<String, String> functionParameters) {
CpcSketch first = ObjectSerDeUtils.DATA_SKETCH_CPC_SER_DE.deserialize((byte[]) value1);
CpcSketch second = ObjectSerDeUtils.DATA_SKETCH_CPC_SER_DE.deserialize((byte[]) value2);
CpcSketch result;
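(Note: this commit only widens the signature of the CPC aggregator above, and of the HLL, ULL, min, max, and sum aggregators below, to accept the functionParameters map; none of those implementations read it. Only the theta and tuple sketch aggregators consume a parameter in this change.)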
@@ -20,12 +20,13 @@

import com.clearspring.analytics.stream.cardinality.CardinalityMergeException;
import com.clearspring.analytics.stream.cardinality.HyperLogLog;
+ import java.util.Map;
import org.apache.pinot.core.common.ObjectSerDeUtils;


public class DistinctCountHLLAggregator implements ValueAggregator {
@Override
- public Object aggregate(Object value1, Object value2) {
+ public Object aggregate(Object value1, Object value2, Map<String, String> functionParameters) {
try {
HyperLogLog first = ObjectSerDeUtils.HYPER_LOG_LOG_SER_DE.deserialize((byte[]) value1);
HyperLogLog second = ObjectSerDeUtils.HYPER_LOG_LOG_SER_DE.deserialize((byte[]) value2);
@@ -18,26 +18,38 @@
*/
package org.apache.pinot.core.segment.processing.aggregator;

+ import java.util.Map;
import org.apache.datasketches.theta.Sketch;
import org.apache.datasketches.theta.Union;
import org.apache.pinot.core.common.ObjectSerDeUtils;
+ import org.apache.pinot.segment.spi.Constants;
import org.apache.pinot.spi.utils.CommonConstants;


public class DistinctCountThetaSketchAggregator implements ValueAggregator {

- private final Union _union;

- public DistinctCountThetaSketchAggregator() {
-   // TODO: Handle configurable nominal entries
-   _union = Union.builder().setNominalEntries(CommonConstants.Helix.DEFAULT_THETA_SKETCH_NOMINAL_ENTRIES).buildUnion();
- }

@Override
- public Object aggregate(Object value1, Object value2) {
+ public Object aggregate(Object value1, Object value2, Map<String, String> functionParameters) {
+ String nominalEntriesParam = functionParameters.get(Constants.THETA_TUPLE_SKETCH_NOMINAL_ENTRIES);

+ int sketchNominalEntries;

+ // Use the explicitly configured nominal entries if present
+ if (nominalEntriesParam != null) {
+ sketchNominalEntries = Integer.parseInt(nominalEntriesParam);
+ } else {
+ // If the functionParameters don't have an explicit nominal entries value set,
+ // use the default value for nominal entries
+ sketchNominalEntries = CommonConstants.Helix.DEFAULT_THETA_SKETCH_NOMINAL_ENTRIES;
+ }

+ Union union = Union.builder().setNominalEntries(sketchNominalEntries).buildUnion();
Sketch first = ObjectSerDeUtils.DATA_SKETCH_THETA_SER_DE.deserialize((byte[]) value1);
Sketch second = ObjectSerDeUtils.DATA_SKETCH_THETA_SER_DE.deserialize((byte[]) value2);
- Sketch result = _union.union(first, second);
+ Sketch result = union.union(first, second);
return ObjectSerDeUtils.DATA_SKETCH_THETA_SER_DE.serialize(result);
}
}
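A minimal, untested sketch of exercising the configurable accuracy, assuming Constants.THETA_TUPLE_SKETCH_NOMINAL_ENTRIES resolves to the key "nominalEntries" (consistent with the task-config parsing at the end of this diff) and that the Pinot and DataSketches classes are on the classpath:

  import java.util.Map;
  import org.apache.datasketches.theta.UpdateSketch;
  import org.apache.pinot.core.common.ObjectSerDeUtils;

  UpdateSketch a = UpdateSketch.builder().build();
  UpdateSketch b = UpdateSketch.builder().build();
  for (long i = 0; i < 1000; i++) {
    a.update(i);       // keys 0..999
    b.update(i + 500); // keys 500..1499, half overlapping
  }
  byte[] bytesA = ObjectSerDeUtils.DATA_SKETCH_THETA_SER_DE.serialize(a.compact());
  byte[] bytesB = ObjectSerDeUtils.DATA_SKETCH_THETA_SER_DE.serialize(b.compact());

  // A smaller nominal entries value trades accuracy for a smaller merged sketch
  Map<String, String> params = Map.of("nominalEntries", "4096");
  byte[] merged = (byte[]) new DistinctCountThetaSketchAggregator().aggregate(bytesA, bytesB, params);
  double estimate = ObjectSerDeUtils.DATA_SKETCH_THETA_SER_DE.deserialize(merged).getEstimate(); // ~1500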
@@ -19,12 +19,13 @@
package org.apache.pinot.core.segment.processing.aggregator;

import com.dynatrace.hash4j.distinctcount.UltraLogLog;
+ import java.util.Map;
import org.apache.pinot.core.common.ObjectSerDeUtils;


public class DistinctCountULLAggregator implements ValueAggregator {
@Override
- public Object aggregate(Object value1, Object value2) {
+ public Object aggregate(Object value1, Object value2, Map<String, String> functionParameters) {
UltraLogLog first = ObjectSerDeUtils.ULTRA_LOG_LOG_OBJECT_SER_DE.deserialize((byte[]) value1);
UltraLogLog second = ObjectSerDeUtils.ULTRA_LOG_LOG_OBJECT_SER_DE.deserialize((byte[]) value2);
// add to the one with a larger P and return that
@@ -18,11 +18,14 @@
*/
package org.apache.pinot.core.segment.processing.aggregator;

+ import java.util.Map;
import org.apache.datasketches.tuple.Sketch;
import org.apache.datasketches.tuple.Union;
import org.apache.datasketches.tuple.aninteger.IntegerSummary;
import org.apache.datasketches.tuple.aninteger.IntegerSummarySetOperations;
import org.apache.pinot.core.common.ObjectSerDeUtils;
+ import org.apache.pinot.segment.spi.Constants;
+ import org.apache.pinot.spi.utils.CommonConstants;


public class IntegerTupleSketchAggregator implements ValueAggregator {
@@ -33,10 +36,24 @@ public IntegerTupleSketchAggregator(IntegerSummary.Mode mode) {
}

@Override
- public Object aggregate(Object value1, Object value2) {
+ public Object aggregate(Object value1, Object value2, Map<String, String> functionParameters) {
+ String nominalEntriesParam = functionParameters.get(Constants.THETA_TUPLE_SKETCH_NOMINAL_ENTRIES);

+ int sketchNominalEntries;

+ // Use the explicitly configured nominal entries if present
+ if (nominalEntriesParam != null) {
+ sketchNominalEntries = Integer.parseInt(nominalEntriesParam);
+ } else {
+ // If the functionParameters don't have an explicit nominal entries value set,
+ // derive the default from the default lgK (nominal entries = 2^lgK)
+ sketchNominalEntries = (int) Math.pow(2, CommonConstants.Helix.DEFAULT_TUPLE_SKETCH_LGK);
+ }

Sketch<IntegerSummary> first = ObjectSerDeUtils.DATA_SKETCH_INT_TUPLE_SER_DE.deserialize((byte[]) value1);
Sketch<IntegerSummary> second = ObjectSerDeUtils.DATA_SKETCH_INT_TUPLE_SER_DE.deserialize((byte[]) value2);
- Sketch<IntegerSummary> result = new Union<>(new IntegerSummarySetOperations(_mode, _mode)).union(first, second);
+ Sketch<IntegerSummary> result =
+ new Union<>(sketchNominalEntries, new IntegerSummarySetOperations(_mode, _mode)).union(first, second);
return ObjectSerDeUtils.DATA_SKETCH_INT_TUPLE_SER_DE.serialize(result);
}
}
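Note the asymmetry in the defaults: DEFAULT_THETA_SKETCH_NOMINAL_ENTRIES is already a nominal entries count, while DEFAULT_TUPLE_SKETCH_LGK is the log2 of the entry count, hence the Math.pow(2, lgK) conversion above (an lgK of 16, for example, corresponds to 65536 nominal entries). An explicitly configured value is interpreted as a nominal entries count by both aggregators, since both read the same Constants.THETA_TUPLE_SKETCH_NOMINAL_ENTRIES key.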
@@ -18,6 +18,7 @@
*/
package org.apache.pinot.core.segment.processing.aggregator;

+ import java.util.Map;
import org.apache.pinot.spi.data.FieldSpec;


@@ -33,7 +34,7 @@ public MaxValueAggregator(FieldSpec.DataType dataType) {
}

@Override
- public Object aggregate(Object value1, Object value2) {
+ public Object aggregate(Object value1, Object value2, Map<String, String> functionParameters) {
Object result;
switch (_dataType) {
case INT:
@@ -18,6 +18,7 @@
*/
package org.apache.pinot.core.segment.processing.aggregator;

+ import java.util.Map;
import org.apache.pinot.spi.data.FieldSpec;


@@ -33,7 +34,7 @@ public MinValueAggregator(FieldSpec.DataType dataType) {
}

@Override
- public Object aggregate(Object value1, Object value2) {
+ public Object aggregate(Object value1, Object value2, Map<String, String> functionParameters) {
Object result;
switch (_dataType) {
case INT:
@@ -18,6 +18,7 @@
*/
package org.apache.pinot.core.segment.processing.aggregator;

+ import java.util.Map;
import org.apache.pinot.spi.data.FieldSpec;


@@ -33,7 +34,7 @@ public SumValueAggregator(FieldSpec.DataType dataType) {
}

@Override
- public Object aggregate(Object value1, Object value2) {
+ public Object aggregate(Object value1, Object value2, Map<String, String> functionParameters) {
Object result;
switch (_dataType) {
case INT:
@@ -18,6 +18,9 @@
*/
package org.apache.pinot.core.segment.processing.aggregator;

+ import java.util.Map;


/**
* Interface for value aggregator
*/
@@ -27,5 +30,5 @@ public interface ValueAggregator {
* Given two values, return the aggregated value
* @return aggregated value given two column values
*/
- Object aggregate(Object value1, Object value2);
+ Object aggregate(Object value1, Object value2, Map<String, String> functionParameters);
}
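For custom or test implementations, the widened contract stays simple. A hypothetical implementation (invented for illustration, not part of this commit) that ignores the parameters:

  import java.util.Map;

  // Keeps the most recent value; functionParameters is accepted but unused
  public class LastValueAggregator implements ValueAggregator {
    @Override
    public Object aggregate(Object value1, Object value2, Map<String, String> functionParameters) {
      return value2;
    }
  }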
@@ -44,12 +44,14 @@ public class SegmentProcessorConfig {
private final List<PartitionerConfig> _partitionerConfigs;
private final MergeType _mergeType;
private final Map<String, AggregationFunctionType> _aggregationTypes;
+ private final Map<String, Map<String, String>> _aggregationFunctionParameters;
private final SegmentConfig _segmentConfig;
private final Consumer<Object> _progressObserver;

private SegmentProcessorConfig(TableConfig tableConfig, Schema schema, TimeHandlerConfig timeHandlerConfig,
List<PartitionerConfig> partitionerConfigs, MergeType mergeType,
- Map<String, AggregationFunctionType> aggregationTypes, SegmentConfig segmentConfig,
+ Map<String, AggregationFunctionType> aggregationTypes,
+ Map<String, Map<String, String>> aggregationFunctionParameters, SegmentConfig segmentConfig,
Consumer<Object> progressObserver) {
TimestampIndexUtils.applyTimestampIndex(tableConfig, schema);
_tableConfig = tableConfig;
@@ -58,6 +60,7 @@ private SegmentProcessorConfig(TableConfig tableConfig, Schema schema, TimeHandl
_partitionerConfigs = partitionerConfigs;
_mergeType = mergeType;
_aggregationTypes = aggregationTypes;
+ _aggregationFunctionParameters = aggregationFunctionParameters;
_segmentConfig = segmentConfig;
_progressObserver = (progressObserver != null) ? progressObserver : p -> {
// Do nothing.
@@ -106,6 +109,13 @@ public Map<String, AggregationFunctionType> getAggregationTypes() {
return _aggregationTypes;
}

+ /**
+ * The aggregation function parameters for the SegmentProcessorFramework's reduce phase with ROLLUP merge type
+ */
+ public Map<String, Map<String, String>> getAggregationFunctionParameters() {
+ return _aggregationFunctionParameters;
+ }

/**
* The SegmentConfig for the SegmentProcessorFramework's reduce phase
*/
@@ -134,6 +144,7 @@ public static class Builder {
private List<PartitionerConfig> _partitionerConfigs;
private MergeType _mergeType;
private Map<String, AggregationFunctionType> _aggregationTypes;
+ private Map<String, Map<String, String>> _aggregationFunctionParameters;
private SegmentConfig _segmentConfig;
private Consumer<Object> _progressObserver;

@@ -167,6 +178,11 @@ public Builder setAggregationTypes(Map<String, AggregationFunctionType> aggregat
return this;
}

+ public Builder setAggregationFunctionParameters(Map<String, Map<String, String>> aggregationFunctionParameters) {
+ _aggregationFunctionParameters = aggregationFunctionParameters;
+ return this;
+ }

public Builder setSegmentConfig(SegmentConfig segmentConfig) {
_segmentConfig = segmentConfig;
return this;
@@ -193,11 +209,14 @@ public SegmentProcessorConfig build() {
if (_aggregationTypes == null) {
_aggregationTypes = Collections.emptyMap();
}
+ if (_aggregationFunctionParameters == null) {
+ _aggregationFunctionParameters = Collections.emptyMap();
+ }
if (_segmentConfig == null) {
_segmentConfig = new SegmentConfig.Builder().build();
}
return new SegmentProcessorConfig(_tableConfig, _schema, _timeHandlerConfig, _partitionerConfigs, _mergeType,
- _aggregationTypes, _segmentConfig, _progressObserver);
+ _aggregationTypes, _aggregationFunctionParameters, _segmentConfig, _progressObserver);
}
}
}
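A hedged sketch of wiring parameters through the builder. Only the setters shown in this diff are confirmed; setTableConfig, setSchema, and setMergeType, along with the tableConfig and schema variables, are assumed here, and "views" is an invented column name:

  SegmentProcessorConfig config = new SegmentProcessorConfig.Builder()
      .setTableConfig(tableConfig)  // assumed setter
      .setSchema(schema)            // assumed setter
      .setMergeType(MergeType.ROLLUP)  // assumed setter
      .setAggregationTypes(Map.of("views", AggregationFunctionType.DISTINCTCOUNTTHETASKETCH))
      .setAggregationFunctionParameters(Map.of("views", Map.of("nominalEntries", "4096")))
      .build();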
@@ -38,7 +38,8 @@ public static Reducer getReducer(String partitionId, GenericRowFileManager fileM
case CONCAT:
return new ConcatReducer(fileManager);
case ROLLUP:
- return new RollupReducer(partitionId, fileManager, processorConfig.getAggregationTypes(), reducerOutputDir);
+ return new RollupReducer(partitionId, fileManager, processorConfig.getAggregationTypes(),
+ processorConfig.getAggregationFunctionParameters(), reducerOutputDir);
case DEDUP:
return new DedupReducer(partitionId, fileManager, reducerOutputDir);
default:
@@ -20,6 +20,7 @@

import java.io.File;
import java.util.ArrayList;
+ import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.apache.commons.io.FileUtils;
@@ -47,14 +48,17 @@ public class RollupReducer implements Reducer {
private final String _partitionId;
private final GenericRowFileManager _fileManager;
private final Map<String, AggregationFunctionType> _aggregationTypes;
+ private final Map<String, Map<String, String>> _aggregationFunctionParameters;
private final File _reducerOutputDir;
private GenericRowFileManager _rollupFileManager;

public RollupReducer(String partitionId, GenericRowFileManager fileManager,
- Map<String, AggregationFunctionType> aggregationTypes, File reducerOutputDir) {
+ Map<String, AggregationFunctionType> aggregationTypes,
+ Map<String, Map<String, String>> aggregationFunctionParameters, File reducerOutputDir) {
_partitionId = partitionId;
_fileManager = fileManager;
_aggregationTypes = aggregationTypes;
+ _aggregationFunctionParameters = aggregationFunctionParameters;
_reducerOutputDir = reducerOutputDir;
}

@@ -91,7 +95,8 @@ private GenericRowFileManager doReduce()
for (FieldSpec fieldSpec : fieldSpecs) {
if (fieldSpec.getFieldType() == FieldType.METRIC) {
aggregatorContextList.add(new AggregatorContext(fieldSpec,
- _aggregationTypes.getOrDefault(fieldSpec.getName(), DEFAULT_AGGREGATOR_TYPE)));
+ _aggregationTypes.getOrDefault(fieldSpec.getName(), DEFAULT_AGGREGATOR_TYPE),
+ _aggregationFunctionParameters.getOrDefault(fieldSpec.getName(), Collections.emptyMap())));
}
}

@@ -159,7 +164,8 @@ private static void aggregateWithNullFields(GenericRow aggregatedRow, GenericRow
} else {
// Non-null field, aggregate the value
aggregatedRow.putValue(column,
- aggregatorContext._aggregator.aggregate(aggregatedRow.getValue(column), rowToAggregate.getValue(column)));
+ aggregatorContext._aggregator.aggregate(aggregatedRow.getValue(column), rowToAggregate.getValue(column),
+ aggregatorContext._functionParameters));
}
}
}
@@ -169,17 +175,21 @@ private static void aggregateWithoutNullFields(GenericRow aggregatedRow, Generic
for (AggregatorContext aggregatorContext : aggregatorContextList) {
String column = aggregatorContext._column;
aggregatedRow.putValue(column,
- aggregatorContext._aggregator.aggregate(aggregatedRow.getValue(column), rowToAggregate.getValue(column)));
+ aggregatorContext._aggregator.aggregate(aggregatedRow.getValue(column), rowToAggregate.getValue(column),
+ aggregatorContext._functionParameters));
}
}

private static class AggregatorContext {
final String _column;
final ValueAggregator _aggregator;
+ final Map<String, String> _functionParameters;

- AggregatorContext(FieldSpec fieldSpec, AggregationFunctionType aggregationType) {
+ AggregatorContext(FieldSpec fieldSpec, AggregationFunctionType aggregationType,
+ Map<String, String> functionParameters) {
_column = fieldSpec.getName();
_aggregator = ValueAggregatorFactory.getValueAggregator(aggregationType, fieldSpec.getDataType());
+ _functionParameters = functionParameters;
}
}
}
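Design note: parameters are resolved once per METRIC column when the AggregatorContext list is built, falling back to Collections.emptyMap() for columns with no configuration, and the same map is then passed on every aggregate() call. This is why the theta and tuple sketch aggregators null-check the parameter lookup instead of assuming the key exists.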
@@ -137,6 +137,28 @@ public static Map<String, AggregationFunctionType> getAggregationTypes(Map<Strin
return aggregationTypes;
}

+ /**
+ * Returns a map from column name to the aggregation function parameters associated with it based on the task config.
+ */
+ public static Map<String, Map<String, String>> getAggregationFunctionParameters(Map<String, String> taskConfig) {
+ Map<String, Map<String, String>> aggregationFunctionParameters = new HashMap<>();
+ String prefix = MergeTask.AGGREGATION_FUNCTION_PARAMETERS_PREFIX;

+ for (Map.Entry<String, String> entry : taskConfig.entrySet()) {
+ String key = entry.getKey();
+ String value = entry.getValue();
+ if (key.startsWith(prefix)) {
+ String[] parts = key.substring(prefix.length()).split("\\.", 2);
+ if (parts.length == 2) {
+ String metricColumn = parts[0];
+ String paramName = parts[1];
+ aggregationFunctionParameters.computeIfAbsent(metricColumn, k -> new HashMap<>()).put(paramName, value);
+ }
+ }
+ }
+ return aggregationFunctionParameters;
+ }

/**
* Returns the segment config based on the task config.
*/
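A short walk-through of the parsing above, with invented column names, assuming the enclosing utility class is MergeTaskUtils:

  Map<String, String> taskConfig = Map.of(
      "aggregationFunctionParameters.views.nominalEntries", "4096",
      "aggregationFunctionParameters.clicks.nominalEntries", "8192",
      "mergeType", "rollup"); // keys without the prefix are ignored

  Map<String, Map<String, String>> params = MergeTaskUtils.getAggregationFunctionParameters(taskConfig);
  // params: {views={nominalEntries=4096}, clicks={nominalEntries=8192}}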