diff --git a/coral-incremental/src/main/java/com/linkedin/coral/incremental/RelNodeCostEstimator.java b/coral-incremental/src/main/java/com/linkedin/coral/incremental/RelNodeCostEstimator.java new file mode 100644 index 000000000..c9a17a26a --- /dev/null +++ b/coral-incremental/src/main/java/com/linkedin/coral/incremental/RelNodeCostEstimator.java @@ -0,0 +1,286 @@ +/** + * Copyright 2024 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.coral.incremental; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; + +import org.apache.calcite.plan.RelOptTable; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.TableScan; +import org.apache.calcite.rel.logical.LogicalJoin; +import org.apache.calcite.rel.logical.LogicalProject; +import org.apache.calcite.rel.logical.LogicalUnion; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexNode; + +import static java.lang.Math.*; + + +/** + * RelNodeCostEstimator is a utility class designed to estimate the cost of executing relational operations + * in a query plan. It uses statistical information about table row counts and column distinct values + * to compute costs associated with different types of relational operations like table scans, joins, + * unions, and projections. + * + *

This class supports loading statistics from a JSON configuration file. + * For a relational operations (RelNode), the execution cost and row count are estimated based on + * these statistics and the input relational expressions. + * + *

The cost estimation takes into account factors such as I/O costs and data shuffling costs. + * The cost weight of writing a row to disk is IOCostValue, and the cost weight of execution is executionCostValue. + * + *

Cost is get from 'getCost' method, which returns the total cost of the query plan, and cost consists of + * execution cost and I/O cost. + */ +public class RelNodeCostEstimator { + + class CostInfo { + // TODO: we may also need to add TableName field. + Double executionCost; + Double outputSize; + + public CostInfo(Double executionCost, Double row) { + this.executionCost = executionCost; + this.outputSize = row; + } + } + + class TableStatistic { + // The number of rows in the table + Double rowCount; + // The number of distinct values in each column + // This doesn't work for nested columns and complex types + Map distinctCountByRow; + + public TableStatistic() { + this.distinctCountByRow = new HashMap<>(); + } + } + + class JoinKey { + String leftTableName; + String rightTableName; + String leftFieldName; + String rightFieldName; + + public JoinKey(String leftTableName, String rightTableName, String leftFieldName, String rightFieldName) { + this.leftTableName = leftTableName; + this.rightTableName = rightTableName; + this.leftFieldName = leftFieldName; + this.rightFieldName = rightFieldName; + } + } + + private Map costStatistic = new HashMap<>(); + + private final Double IOCostValue; + + private final Double executionCostValue; + + public RelNodeCostEstimator(Double IOCostValue, Double executionCostValue) { + this.IOCostValue = IOCostValue; + this.executionCostValue = executionCostValue; + } + + /** + * Loads statistics from a JSON configuration file and stores them in internal data structures. + * + *

This method reads a JSON file from the specified path, parses its content, and extracts + * statistical information. For each table in the JSON object, it retrieves the row count and + * distinct counts for each column. These values are then stored in the `stat` and `distinctStat` + * maps, respectively. + * + * @param configPath the path to the JSON configuration file + */ + public void loadStatistic(String configPath) throws IOException { + try { + String content = new String(Files.readAllBytes(Paths.get(configPath))); + JsonObject jsonObject = new JsonParser().parse(content).getAsJsonObject(); + for (Map.Entry entry : jsonObject.entrySet()) { + TableStatistic tableStatistic = new TableStatistic(); + String tableName = entry.getKey(); + JsonObject tableObject = entry.getValue().getAsJsonObject(); + + Double rowCount = tableObject.get("RowCount").getAsDouble(); + + JsonObject distinctCounts = tableObject.getAsJsonObject("DistinctCounts"); + + tableStatistic.rowCount = rowCount; + + for (Map.Entry distinctEntry : distinctCounts.entrySet()) { + String columnName = distinctEntry.getKey(); + Double distinctCount = distinctEntry.getValue().getAsDouble(); + + tableStatistic.distinctCountByRow.put(columnName, distinctCount); + } + costStatistic.put(tableName, tableStatistic); + + } + } catch (IOException e) { + throw new IOException("Failed to load statistics from the configuration file: " + configPath, e); + } + + } + + /** + * Returns the total cost of executing a relational operation. + * + *

This method computes the cost of executing a relational operation based on the input + * relational expression. The cost is calculated as the sum of the execution cost and the I/O cost. + * We assume that I/O only occurs at the root of the query plan (Project) where we write the output to disk. + * So the cost is the sum of the execution cost of all children RelNodes and IOCostValue * outputSize of the root Project RelNode. + * + * @param rel the input relational expression + * @return the total cost of executing the relational operation + */ + public Double getCost(RelNode rel) { + CostInfo executionCostInfo = getExecutionCost(rel); + Double writeCost = executionCostInfo.outputSize * IOCostValue; + return executionCostInfo.executionCost * executionCostValue + writeCost; + } + + private CostInfo getExecutionCost(RelNode rel) { + if (rel instanceof TableScan) { + return getExecutionCostTableScan((TableScan) rel); + } else if (rel instanceof LogicalJoin) { + return getExecutionCostJoin((LogicalJoin) rel); + } else if (rel instanceof LogicalUnion) { + return getExecutionCostUnion((LogicalUnion) rel); + } else if (rel instanceof LogicalProject) { + return getExecutionCostProject((LogicalProject) rel); + } + throw new IllegalArgumentException("Unsupported relational operation: " + rel.getClass().getSimpleName()); + } + + private CostInfo getExecutionCostTableScan(TableScan scan) { + RelOptTable originalTable = scan.getTable(); + String tableName = getTableName(originalTable); + try { + TableStatistic tableStat = costStatistic.get(tableName); + Double rowCount = tableStat.rowCount; + return new CostInfo(rowCount, rowCount); + } catch (NullPointerException e) { + throw new IllegalArgumentException("Table statistics not found for table: " + tableName); + } + } + + private String getTableName(RelOptTable table) { + return String.join(".", table.getQualifiedName()); + } + + private CostInfo getExecutionCostJoin(LogicalJoin join) { + RelNode left = join.getLeft(); + RelNode right = join.getRight(); + CostInfo leftCost = getExecutionCost(left); + CostInfo rightCost = getExecutionCost(right); + Double joinSize = estimateJoinSize(join, leftCost.outputSize, rightCost.outputSize); + // The execution cost of a join is the maximum execution cost of its children because the execution cost of a single RelNode + // is mainly determined by the cost of the shuffle operation. + // And in modern distributed systems, the shuffle cost is dominated by the largest shuffle. + return new CostInfo(max(leftCost.executionCost, rightCost.executionCost), joinSize); + } + + private List getJoinKeys(LogicalJoin join) { + List joinKeys = new ArrayList<>(); + RexNode condition = join.getCondition(); + if (condition instanceof RexCall) { + getJoinKeysFromJoinCondition((RexCall) condition, join, joinKeys); + } + // Assertion to check if joinKeys.size() is greater than or equal to 1 + if (joinKeys.size() < 1) { + throw new IllegalArgumentException("Join keys size is less than 1"); + } + return joinKeys; + } + + private void getJoinKeysFromJoinCondition(RexCall call, LogicalJoin join, List joinKeys) { + if (call.getOperator().getName().equalsIgnoreCase("AND")) { + // Process each operand of the AND separately + for (RexNode operand : call.getOperands()) { + if (operand instanceof RexCall) { + getJoinKeysFromJoinCondition((RexCall) operand, join, joinKeys); + } + } + } else { + // Process the join condition (e.g., EQUALS) + List operands = call.getOperands(); + if (operands.size() == 2 && operands.get(0) instanceof RexInputRef && operands.get(1) instanceof RexInputRef) { + RexInputRef leftRef = (RexInputRef) operands.get(0); + RexInputRef rightRef = (RexInputRef) operands.get(1); + RelDataType leftType = join.getLeft().getRowType(); + RelDataType rightType = join.getRight().getRowType(); + + int leftIndex = leftRef.getIndex(); + int rightIndex = rightRef.getIndex() - leftType.getFieldCount(); + + RelDataTypeField leftField = leftType.getFieldList().get(leftIndex); + String leftTableName = getTableName(join.getLeft().getTable()); + String leftFieldName = leftField.getName(); + RelDataTypeField rightField = rightType.getFieldList().get(rightIndex); + String rightTableName = getTableName(join.getRight().getTable()); + String rightFieldName = rightField.getName(); + + joinKeys.add(new JoinKey(leftTableName, rightTableName, leftFieldName, rightFieldName)); + } + } + } + + private Double estimateJoinSize(LogicalJoin join, Double leftSize, Double rightSize) { + List joinKeys = getJoinKeys(join); + Double selectivity = 1.0; + for (JoinKey joinKey : joinKeys) { + String leftTableName = joinKey.leftTableName; + String rightTableName = joinKey.rightTableName; + String leftFieldName = joinKey.leftFieldName; + String rightFieldName = joinKey.rightFieldName; + try { + TableStatistic leftTableStat = costStatistic.get(leftTableName); + TableStatistic rightTableStat = costStatistic.get(rightTableName); + Double leftCardinality = leftTableStat.rowCount; + Double rightCardinality = rightTableStat.rowCount; + Double leftDistinct = leftTableStat.distinctCountByRow.getOrDefault(leftFieldName, leftCardinality); + Double rightDistinct = rightTableStat.distinctCountByRow.getOrDefault(rightFieldName, rightCardinality); + selectivity *= 1 / max(leftDistinct, rightDistinct); + } catch (NullPointerException e) { + throw new IllegalArgumentException( + "Table statistics not found for table: " + leftTableName + " or " + rightTableName); + } + } + return leftSize * rightSize * selectivity; + } + + private CostInfo getExecutionCostUnion(LogicalUnion union) { + Double unionCost = 0.0; + Double unionSize = 0.0; + RelNode input; + for (Iterator var4 = union.getInputs().iterator(); var4.hasNext();) { + input = (RelNode) var4.next(); + CostInfo inputCost = getExecutionCost(input); + unionSize += inputCost.outputSize; + unionCost = max(inputCost.executionCost, unionCost); + } + unionCost *= 2; + return new CostInfo(unionCost, unionSize); + } + + private CostInfo getExecutionCostProject(LogicalProject project) { + return getExecutionCost(project.getInput()); + } +} diff --git a/coral-incremental/src/test/java/com/linkedin/coral/incremental/RelNodeCostEstimatorTest.java b/coral-incremental/src/test/java/com/linkedin/coral/incremental/RelNodeCostEstimatorTest.java new file mode 100644 index 000000000..2aae6c8ac --- /dev/null +++ b/coral-incremental/src/test/java/com/linkedin/coral/incremental/RelNodeCostEstimatorTest.java @@ -0,0 +1,106 @@ +/** + * Copyright 2024 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.coral.incremental; + +import java.io.File; +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import org.apache.calcite.rel.RelNode; +import org.apache.commons.io.FileUtils; +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.hive.metastore.api.MetaException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.testng.annotations.AfterTest; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import static com.linkedin.coral.incremental.TestUtils.*; +import static org.testng.Assert.*; + + +public class RelNodeCostEstimatorTest { + private HiveConf conf; + + private RelNodeCostEstimator estimator; + + static final String TEST_JSON_FILE_DIR = "src/test/resources/"; + + @BeforeClass + public void beforeClass() throws HiveException, MetaException, IOException { + conf = TestUtils.loadResourceHiveConf(); + estimator = new RelNodeCostEstimator(2.0, 1.0); + TestUtils.initializeViews(conf); + } + + @AfterTest + public void afterClass() throws IOException { + FileUtils.deleteDirectory(new File(conf.get(CORAL_INCREMENTAL_TEST_DIR))); + } + + public Map fakeStatData() { + Map stat = new HashMap<>(); + stat.put("hive.test.bar1", 80.0); + stat.put("hive.test.bar2", 20.0); + stat.put("hive.test.bar1_prev", 40.0); + stat.put("hive.test.bar2_prev", 10.0); + stat.put("hive.test.bar1_delta", 60.0); + stat.put("hive.test.bar2_delta", 10.0); + return stat; + } + + @Test + public void testSimpleSelectAll() throws IOException { + String sql = "SELECT * FROM test.bar1"; + RelNode relNode = hiveToRelConverter.convertSql(sql); + estimator.loadStatistic(TEST_JSON_FILE_DIR + "statistic.json"); + assertEquals(estimator.getCost(relNode), 300.0); + } + + @Test + public void testSimpleJoin() throws IOException { + String sql = "SELECT * FROM test.bar1 JOIN test.bar2 ON test.bar1.x = test.bar2.x"; + RelNode relNode = hiveToRelConverter.convertSql(sql); + estimator.loadStatistic(TEST_JSON_FILE_DIR + "statistic.json"); + assertEquals(estimator.getCost(relNode), 500.0); + } + + @Test + public void testSimpleUnion() throws IOException { + String sql = "SELECT *\n" + "FROM test.bar1 AS bar1\n" + "INNER JOIN test.bar2 AS bar2 ON bar1.x = bar2.x\n" + + "UNION ALL\n" + "SELECT *\n" + "FROM test.bar3 AS bar3\n" + "INNER JOIN test.bar2 AS bar2 ON bar3.x = bar2.x"; + RelNode relNode = hiveToRelConverter.convertSql(sql); + estimator.loadStatistic(TEST_JSON_FILE_DIR + "statistic.json"); + assertEquals(estimator.getCost(relNode), 680.0); + } + + @Test + public void testUnsupportOperator() throws IOException { + String sql = "SELECT * FROM test.bar1 WHERE x = 1"; + RelNode relNode = hiveToRelConverter.convertSql(sql); + estimator.loadStatistic(TEST_JSON_FILE_DIR + "statistic.json"); + try { + estimator.getCost(relNode); + fail("Should throw exception"); + } catch (RuntimeException e) { + assertEquals(e.getMessage(), "Unsupported relational operation: " + "LogicalFilter"); + } + } + + @Test + public void testNoStatistic() throws IOException { + String sql = "SELECT * FROM test.foo"; + RelNode relNode = hiveToRelConverter.convertSql(sql); + estimator.loadStatistic(TEST_JSON_FILE_DIR + "statistic.json"); + try { + estimator.getCost(relNode); + fail("Should throw exception"); + } catch (RuntimeException e) { + assertEquals(e.getMessage(), "Table statistics not found for table: " + "hive.test.foo"); + } + } +} diff --git a/coral-incremental/src/test/resources/statistic.json b/coral-incremental/src/test/resources/statistic.json new file mode 100644 index 000000000..0b75555e9 --- /dev/null +++ b/coral-incremental/src/test/resources/statistic.json @@ -0,0 +1,20 @@ +{ + "hive.test.bar1": { + "RowCount": 100, + "DistinctCounts": { + "x": 10 + } + }, + "hive.test.bar2": { + "RowCount": 20, + "DistinctCounts": { + "x": 5 + } + }, + "hive.test.bar3": { + "RowCount": 50, + "DistinctCounts": { + "x": 25 + } + } +} \ No newline at end of file