From c1e7abedd818dfd0a1984f38811e73a0879bc170 Mon Sep 17 00:00:00 2001
From: Jiaqi Zhang <jiaqizhang@meta.com>
Date: Tue, 17 Dec 2024 16:17:52 -0800
Subject: [PATCH] feat: Separate null count and minmax from column stats
 (#11860)

Summary:

PR to address #11741

- Removed the columnHasNulls_ flags in RowContainer and replaced them with the null counts kept in the column stats
- Separated null count/sumBytes from min/max. When rows are erased, only min/max is invalidated.

Differential Revision: D67229925
---
 velox/exec/HashProbe.cpp              |  4 +-
 velox/exec/RowContainer.cpp           | 49 ++++++++++++++++----
 velox/exec/RowContainer.h             | 67 +++++++++++++++++++++------
 velox/exec/tests/RowContainerTest.cpp | 48 ++++++++++++++++---
 4 files changed, 136 insertions(+), 32 deletions(-)

diff --git a/velox/exec/HashProbe.cpp b/velox/exec/HashProbe.cpp
index a788cc8e2ae5..8fe132dc0dbb 100644
--- a/velox/exec/HashProbe.cpp
+++ b/velox/exec/HashProbe.cpp
@@ -302,8 +302,8 @@ std::optional<uint64_t> HashProbe::estimatedRowSize(
   std::vector<RowColumn::Stats> varSizeListColumnsStats;
   varSizeListColumnsStats.reserve(varSizedColumns.size());
   for (uint32_t i = 0; i < varSizedColumns.size(); ++i) {
-    auto statsOpt = columnStats(varSizedColumns[i]);
-    if (!statsOpt.has_value()) {
+    const auto statsOpt = columnStats(varSizedColumns[i]);
+    if (!statsOpt.has_value() || !statsOpt->minMaxColumnStatsValid()) {
       return std::nullopt;
     }
     varSizeListColumnsStats.push_back(statsOpt.value());
diff --git a/velox/exec/RowContainer.cpp b/velox/exec/RowContainer.cpp
index 74d49266b0fd..d1337c17455c 100644
--- a/velox/exec/RowContainer.cpp
+++ b/velox/exec/RowContainer.cpp
@@ -188,7 +188,6 @@ RowContainer::RowContainer(
     if (nullableKeys_) {
       ++nullOffset;
     }
-    columnHasNulls_.push_back(false);
   }
   // Make offset at least sizeof pointer so that there is space for a
   // free list next pointer below the bit at 'freeFlagOffset_'.
@@ -217,7 +216,6 @@ RowContainer::RowContainer(
     nullOffsets_.push_back(nullOffset);
     ++nullOffset;
     isVariableWidth |= !type->isFixedWidth();
-    columnHasNulls_.push_back(false);
   }
   if (hasProbedFlag) {
     nullOffsets_.push_back(nullOffset);
@@ -336,16 +334,34 @@ char* RowContainer::initializeRow(char* row, bool reuse) {
   return row;
 }
 
+void RowContainer::removeRowFromColumnStats(
+    const char* row,
+    bool keepNullCount) {
+  // Take every cell of 'row' out of the aggregated stats of its column.
+  for (auto column = 0; column < types_.size(); ++column) {
+    if (isNullAt(row, columnAt(column))) {
+      rowColumnsStats_[column].removeCellSize(0, true, keepNullCount);
+      continue;
+    }
+    const int32_t bytes = types_[column]->isFixedWidth()
+        ? fixedSizeAt(column)
+        : variableSizeAt(row, column);
+    rowColumnsStats_[column].removeCellSize(bytes, false, keepNullCount);
+  }
+  invalidateMinMaxColumnStats();
+}
+
 void RowContainer::eraseRows(folly::Range<char**> rows) {
   freeRowsExtraMemory(rows, /*freeNextRowVector=*/true);
   for (auto* row : rows) {
     VELOX_CHECK(!bits::isBitSet(row, freeFlagOffset_), "Double free of row");
+    removeRowFromColumnStats(row, /*keepNullCount=*/false);
+    // NOTE(review): the call above runs after freeRowsExtraMemory(); verify.
     bits::setBit(row, freeFlagOffset_);
     nextFree(row) = firstFreeRow_;
     firstFreeRow_ = row;
   }
   numFreeRows_ += rows.size();
-  invalidateColumnStats();
 }
 
 int32_t RowContainer::findRows(folly::Range<char**> rows, char** result) const {
@@ -466,11 +482,26 @@ void RowContainer::freeRowsExtraMemory(
   numRows_ -= rows.size();
 }
 
-void RowContainer::invalidateColumnStats() {
-  if (rowColumnsStats_.empty()) {
-    return;
+void RowColumn::Stats::removeCellSize(
+    int32_t bytes,
+    bool isNull,
+    bool keepNullCnt) {
+  // Only nullCount_, nonNullCount_ and sumBytes_ are adjusted per cell.
+  // Min/max would need the full column data, which the stats do not keep,
+  // so they are marked invalid at the end instead of being updated.
+  if (isNull) {
+    VELOX_DCHECK_EQ(bytes, 0);
+    if (!keepNullCnt) {
+      --nullCount_;
+    }
+  } else {
+    --nonNullCount_;
+    sumBytes_ -= bytes;
+    if (keepNullCnt) {
+      ++nullCount_;
+    }
   }
-  rowColumnsStats_.clear();
+  invalidateMinMaxColumnStats();
 }
 
 // static
@@ -816,7 +847,6 @@ void RowContainer::storeComplexType(
   if (decoded.isNullAt(index)) {
     VELOX_DCHECK(nullMask);
     row[nullByte] |= nullMask;
-    updateColumnHasNulls(column, true);
     return;
   }
   RowSizeTracker tracker(row[rowSizeOffset_], *stringAllocator_);
@@ -989,6 +1019,9 @@ void RowContainer::clear() {
   normalizedKeySize_ = originalNormalizedKeySize_;
   numFreeRows_ = 0;
   firstFreeRow_ = nullptr;
+
+  rowColumnsStats_.clear();
+  rowColumnsStats_.resize(types_.size());
 }
 
 void RowContainer::setProbedFlag(char** rows, int32_t numRows) {
diff --git a/velox/exec/RowContainer.h b/velox/exec/RowContainer.h
index 72cd08e276dd..3260aa6f465b 100644
--- a/velox/exec/RowContainer.h
+++ b/velox/exec/RowContainer.h
@@ -187,6 +187,8 @@ class RowColumn {
       ++nullCount_;
     }
 
+    void removeCellSize(int32_t bytes, bool isNull, bool keepNullCnt);
+
     int32_t maxBytes() const {
       return maxBytes_;
     }
@@ -218,6 +220,23 @@ class RowColumn {
       return nullCount_ + nonNullCount_;
     }
 
+    void invalidateMinMaxColumnStats() {
+      minMaxStatsValid_ = false; // Only the flag changes; other stats are kept.
+    }
+    /// Returns false once invalidateMinMaxColumnStats() ran, until reset().
+    bool minMaxColumnStatsValid() const {
+      return minMaxStatsValid_;
+    }
+    /// Zeroes the byte stats and cell counts and marks min/max as valid again.
+    void reset() {
+      minBytes_ = 0;
+      maxBytes_ = 0;
+      sumBytes_ = 0;
+      nonNullCount_ = 0;
+      nullCount_ = 0;
+      minMaxStatsValid_ = true;
+    }
+
     /// Merges multiple aggregated stats of the same column into a single one.
     static Stats merge(const std::vector<Stats>& statsList);
 
@@ -225,6 +244,7 @@ class RowColumn {
     // Aggregated stats for non-null rows of the column.
     int32_t minBytes_{0};
     int32_t maxBytes_{0};
+    bool minMaxStatsValid_{true};
     uint64_t sumBytes_{0};
 
     uint32_t nonNullCount_{0};
@@ -316,15 +336,26 @@ class RowContainer {
              : 0);
   }
 
+  /// Removes the data of a given row from the column stats by iterating
+  /// over each column in the row and updates the column stats.
+  ///
+  /// @param row - The row from which the column stats are to be updated.
+  /// @param keepNullCount - If true, the null count is kept for a cell that
+  /// was null and is incremented for a cell that was not null.
+  void removeRowFromColumnStats(const char* row, bool keepNullCount);
+
   /// Sets all fields, aggregates, keys and dependents to null. Used when making
   /// a row with uninitialized keys for aggregates with no-op partial
   /// aggregation.
   void setAllNull(char* row) {
+    // Column stats must be adjusted before the null bits are overwritten
+    // below; afterwards cells read as null and their sizes are lost.
+    // NOTE(review): assumes 'row' was already counted in the stats - confirm.
+    removeRowFromColumnStats(row, /*keepNullCount=*/true);
     if (!nullOffsets_.empty()) {
       memset(row + nullByte(nullOffsets_[0]), 0xff, initialNulls_.size());
       bits::clearBit(row, freeFlagOffset_);
     }
-    invalidateColumnStats();
   }
 
   /// The row size excluding any out-of-line stored variable length values.
@@ -817,13 +848,25 @@ class RowContainer {
   /// invalidated. Any row erase operations will invalidate column stats.
   std::optional<RowColumn::Stats> columnStats(int32_t columnIndex) const;
 
+  uint32_t columnNullCount(int32_t columnIndex) const {
+    return rowColumnsStats_[columnIndex].nullCount();
+  }
+  /// Returns true if no column's min/max stats have been invalidated.
+  bool rowColumnsStatsMinMaxValid() const {
+    const auto valid = [](const auto& stats) {
+      return stats.minMaxColumnStatsValid();
+    };
+    return std::all_of(rowColumnsStats_.begin(), rowColumnsStats_.end(), valid);
+  }
+
   const auto& keyTypes() const {
     return keyTypes_;
   }
 
-  /// Returns true if specified column may have nulls, false otherwise.
+  /// Returns true if the stats of 'columnIndex' record at least one null.
   inline bool columnHasNulls(int32_t columnIndex) const {
-    return columnHasNulls_[columnIndex];
+    const auto stats = columnStats(columnIndex);
+    return stats.has_value() && stats->nullCount() > 0;
   }
 
   const std::vector<Accumulator>& accumulators() const {
@@ -1015,7 +1058,6 @@ class RowContainer {
       // Do not leave an uninitialized value in the case of a
       // null. This is an error with valgrind/asan.
       *reinterpret_cast<T*>(row + offset) = T();
-      updateColumnHasNulls(columnIndex, true);
       return;
     }
     if constexpr (std::is_same_v<T, StringView>) {
@@ -1466,14 +1508,12 @@ class RowContainer {
       char* row,
       int32_t columnIndex);
 
-  // Light weight aggregated column stats does not support row erasures. This
+  // Min/max column stats do not support row erasures. This
   // method is called whenever a row is erased.
-  void invalidateColumnStats();
-
-  // Updates the specific column's columnHasNulls_ flag, if 'hasNulls' is true.
-  // columnHasNulls_ flag is false by default.
-  inline void updateColumnHasNulls(int32_t columnIndex, bool hasNulls) {
-    columnHasNulls_[columnIndex] = columnHasNulls_[columnIndex] || hasNulls;
+  void invalidateMinMaxColumnStats() {
+    for (auto& stats : rowColumnsStats_) { // By reference: mutate in place.
+      stats.invalidateMinMaxColumnStats();
+    }
   }
 
   const std::vector<TypePtr> keyTypes_;
@@ -1484,8 +1524,6 @@ class RowContainer {
 
   const std::unique_ptr<HashStringAllocator> stringAllocator_;
 
-  std::vector<bool> columnHasNulls_;
-
   // Indicates if we can add new row to this row container. It is set to false
   // after user calls 'getRowPartitions()' to create 'rowPartitions' object for
   // parallel join build.
@@ -1510,7 +1548,7 @@ class RowContainer {
   // Offset and null indicator offset of non-aggregate fields as a single word.
   // Corresponds pairwise to 'types_'.
   std::vector<RowColumn> rowColumns_;
-  // Optional aggregated column stats(e.g. min/max size) for non-aggregate
+  // Aggregated column stats(e.g. min/max size) for non-aggregate
   // fields. Index aligns with 'rowColumns_'.
   std::vector<RowColumn::Stats> rowColumnsStats_;
   // Bit offset of the probed flag for a full or right outer join  payload. 0 if
@@ -1640,7 +1678,6 @@ inline void RowContainer::storeWithNulls<TypeKind::HUGEINT>(
   if (decoded.isNullAt(rowIndex)) {
     row[nullByte] |= nullMask;
     memset(row + offset, 0, sizeof(int128_t));
-    updateColumnHasNulls(columnIndex, true);
     return;
   }
   HugeInt::serialize(decoded.valueAt<int128_t>(rowIndex), row + offset);
diff --git a/velox/exec/tests/RowContainerTest.cpp b/velox/exec/tests/RowContainerTest.cpp
index a2040721f664..0fdc27ada6ae 100644
--- a/velox/exec/tests/RowContainerTest.cpp
+++ b/velox/exec/tests/RowContainerTest.cpp
@@ -90,8 +90,10 @@ class RowContainerTestHelper {
         if (rowContainer_->types_[i]->isFixedWidth()) {
           continue;
         }
-        VELOX_CHECK_EQ(expectedStats.maxBytes(), storedStats.maxBytes());
-        VELOX_CHECK_EQ(expectedStats.minBytes(), storedStats.minBytes());
+        if (storedStats.minMaxColumnStatsValid()) {
+          VELOX_CHECK_EQ(expectedStats.maxBytes(), storedStats.maxBytes());
+          VELOX_CHECK_EQ(expectedStats.minBytes(), storedStats.minBytes());
+        }
         VELOX_CHECK_EQ(expectedStats.sumBytes(), storedStats.sumBytes());
         VELOX_CHECK_EQ(expectedStats.avgBytes(), storedStats.avgBytes());
         VELOX_CHECK_EQ(
@@ -2529,11 +2531,11 @@ TEST_F(RowContainerTest, invalidatedColumnStats) {
     invalidateFunc(rowContainer.get(), rows);
     RowContainerTestHelper(rowContainer.get()).checkConsistency();
 
-    ASSERT_FALSE(rowContainer->columnStats(0).has_value());
-    ASSERT_FALSE(rowContainer->columnStats(1).has_value());
-    ASSERT_FALSE(rowContainer->columnStats(2).has_value());
-    ASSERT_FALSE(rowContainer->columnStats(3).has_value());
-    ASSERT_FALSE(rowContainer->columnStats(4).has_value());
+    ASSERT_TRUE(rowContainer->columnStats(0).has_value());
+    ASSERT_TRUE(rowContainer->columnStats(1).has_value());
+    ASSERT_TRUE(rowContainer->columnStats(2).has_value());
+    ASSERT_TRUE(rowContainer->columnStats(3).has_value());
+    ASSERT_TRUE(rowContainer->columnStats(4).has_value());
   }
 }
 
@@ -2596,6 +2598,22 @@ TEST_F(RowContainerTest, rowColumnStats) {
   EXPECT_EQ(stats.nonNullCount(), 6);
   EXPECT_EQ(stats.nullCount(), 4);
   EXPECT_EQ(stats.numCells(), 10);
+
+  stats.removeCellSize(25, false, false);
+  EXPECT_EQ(stats.minMaxColumnStatsValid(), false);
+  EXPECT_EQ(stats.sumBytes(), 60);
+  EXPECT_EQ(stats.avgBytes(), 12);
+  EXPECT_EQ(stats.numCells(), 9);
+  EXPECT_EQ(stats.nonNullCount(), 5);
+  EXPECT_EQ(stats.nullCount(), 4);
+
+  stats.removeCellSize(0, true, false);
+  EXPECT_EQ(stats.minMaxColumnStatsValid(), false);
+  EXPECT_EQ(stats.sumBytes(), 60);
+  EXPECT_EQ(stats.avgBytes(), 12);
+  EXPECT_EQ(stats.numCells(), 8);
+  EXPECT_EQ(stats.nonNullCount(), 5);
+  EXPECT_EQ(stats.nullCount(), 3);
 }
 
 TEST_F(RowContainerTest, storeAndCollectColumnStats) {
@@ -2636,5 +2654,21 @@ TEST_F(RowContainerTest, storeAndCollectColumnStats) {
       EXPECT_EQ(stats.avgBytes(), 13);
     }
   }
+
+  rowContainer->eraseRows(folly::Range(rows.data(), 10)); // there are 2 nulls
+  EXPECT_EQ(rowContainer->rowColumnsStatsMinMaxValid(), false);
+  for (int i = 0; i < rowContainer->columnTypes().size(); ++i) {
+    const auto stats = rowContainer->columnStats(i).value();
+    EXPECT_EQ(stats.nonNullCount(), 849);
+    EXPECT_EQ(stats.nullCount(), 141);
+    EXPECT_EQ(stats.numCells(), kNumRows - 10);
+    if (rowVector->childAt(i)->typeKind() == TypeKind::VARCHAR) {
+      EXPECT_EQ(stats.sumBytes(), 11809);
+      EXPECT_EQ(stats.avgBytes(), 13);
+    }
+  }
+  rowContainer->clear();
+  EXPECT_TRUE(rowContainer->rowColumnsStatsMinMaxValid());
 }
+
 } // namespace facebook::velox::exec::test