diff --git a/cpp/src/arrow/acero/asof_join_node.cc b/cpp/src/arrow/acero/asof_join_node.cc index a22d6ce9d4a80..4f4f9d65a2d09 100644 --- a/cpp/src/arrow/acero/asof_join_node.cc +++ b/cpp/src/arrow/acero/asof_join_node.cc @@ -566,6 +566,7 @@ class InputState { return queue_.UnsyncFront(); } +// TODO(jeraguilon): consolidate #define LATEST_VAL_CASE(id, val) \ case Type::id: { \ using T = typename TypeIdTraits::Type; \ diff --git a/cpp/src/arrow/acero/backpressure_handler.h b/cpp/src/arrow/acero/backpressure_handler.h index ed44e531e2de5..7fa09745ad561 100644 --- a/cpp/src/arrow/acero/backpressure_handler.h +++ b/cpp/src/arrow/acero/backpressure_handler.h @@ -1,3 +1,7 @@ +<<<<<<< HEAD +======= + +>>>>>>> b34c999b6 (Create sorted merge node) // Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information @@ -16,6 +20,10 @@ // under the License. #pragma once +<<<<<<< HEAD +======= +#include "arrow/acero/exec_plan.h" +>>>>>>> b34c999b6 (Create sorted merge node) #include "arrow/acero/options.h" #include @@ -24,15 +32,26 @@ namespace arrow::acero { class BackpressureHandler { private: +<<<<<<< HEAD BackpressureHandler(size_t low_threshold, size_t high_threshold, std::unique_ptr backpressure_control) : low_threshold_(low_threshold), +======= + BackpressureHandler(ExecNode* input, size_t low_threshold, size_t high_threshold, + std::unique_ptr backpressure_control) + : input_(input), + low_threshold_(low_threshold), +>>>>>>> b34c999b6 (Create sorted merge node) high_threshold_(high_threshold), backpressure_control_(std::move(backpressure_control)) {} public: static Result Make( +<<<<<<< HEAD size_t low_threshold, size_t high_threshold, +======= + ExecNode* input, size_t low_threshold, size_t high_threshold, +>>>>>>> b34c999b6 (Create sorted merge node) std::unique_ptr backpressure_control) { if (low_threshold >= high_threshold) { return Status::Invalid("low threshold (", low_threshold, @@ -41,7 +60,11 @@ class BackpressureHandler { if (backpressure_control == NULLPTR) { return Status::Invalid("null backpressure control parameter"); } +<<<<<<< HEAD BackpressureHandler backpressure_handler(low_threshold, high_threshold, +======= + BackpressureHandler backpressure_handler(input, low_threshold, high_threshold, +>>>>>>> b34c999b6 (Create sorted merge node) std::move(backpressure_control)); return std::move(backpressure_handler); } @@ -54,7 +77,20 @@ class BackpressureHandler { } } +<<<<<<< HEAD + private: +======= + Status ForceShutdown() { + // It may be unintuitive to call Resume() here, but this is to avoid a deadlock. + // Since acero's executor won't terminate if any one node is paused, we need to + // force resume the node before stopping production. + backpressure_control_->Resume(); + return input_->StopProducing(); + } + private: + ExecNode* input_; +>>>>>>> b34c999b6 (Create sorted merge node) size_t low_threshold_; size_t high_threshold_; std::unique_ptr backpressure_control_; diff --git a/cpp/src/arrow/acero/concurrent_queue.h b/cpp/src/arrow/acero/concurrent_queue.h index 2ec9caa5856f4..c3bb29eac0718 100644 --- a/cpp/src/arrow/acero/concurrent_queue.h +++ b/cpp/src/arrow/acero/concurrent_queue.h @@ -140,6 +140,11 @@ class BackpressureConcurrentQueue : public ConcurrentQueue { return ConcurrentQueue::TryPopUnlocked(); } +<<<<<<< HEAD +======= + Status ForceShutdown() { return handler_.ForceShutdown(); } + +>>>>>>> b34c999b6 (Create sorted merge node) private: BackpressureHandler handler_; }; diff --git a/cpp/src/arrow/acero/sorted_merge_node.cc b/cpp/src/arrow/acero/sorted_merge_node.cc index 4861c153afd20..c334b29c49380 100644 --- a/cpp/src/arrow/acero/sorted_merge_node.cc +++ b/cpp/src/arrow/acero/sorted_merge_node.cc @@ -27,11 +27,17 @@ #include "arrow/acero/options.h" #include "arrow/acero/query_context.h" #include "arrow/acero/time_series_util.h" +<<<<<<< HEAD #include "arrow/acero/unmaterialized_table.h" #include "arrow/acero/util.h" #include "arrow/array/builder_base.h" #include "arrow/result.h" #include "arrow/type_fwd.h" +======= +#include "arrow/acero/util.h" +#include "arrow/array/builder_base.h" +#include "arrow/result.h" +>>>>>>> b34c999b6 (Create sorted merge node) #include "arrow/util/logging.h" namespace { @@ -42,7 +48,11 @@ struct Defer { ~Defer() noexcept { callable(); } }; +<<<<<<< HEAD std::vector GetInputLabels( +======= +std::vector getInputLabels( +>>>>>>> b34c999b6 (Create sorted merge node) const arrow::acero::ExecNode::NodeVector& inputs) { std::vector labels(inputs.size()); for (size_t i = 0; i < inputs.size(); i++) { @@ -64,6 +74,7 @@ inline bool std_has(const T& container, const V& val) { } // namespace namespace arrow::acero { +<<<<<<< HEAD namespace sorted_merge { @@ -72,10 +83,55 @@ namespace sorted_merge { using UnmaterializedSlice = arrow::acero::UnmaterializedSlice<1>; using UnmaterializedCompositeTable = arrow::acero::UnmaterializedCompositeTable<1>; +======= +namespace { +>>>>>>> b34c999b6 (Create sorted merge node) using row_index_t = uint64_t; using time_unit_t = uint64_t; using col_index_t = int; +<<<<<<< HEAD +======= +template ::BuilderType> +enable_if_boolean BuilderAppend(Builder& builder, + const std::shared_ptr& source, + size_t row) { + if (source->IsNull(row)) { + builder.UnsafeAppendNull(); + return Status::OK(); + } + builder.UnsafeAppend(bit_util::GetBit(source->template GetValues(1), row)); + return Status::OK(); +} + +template ::BuilderType> +arrow::enable_if_t::value && !is_boolean_type::value, + Status> +BuilderAppend(Builder& builder, const std::shared_ptr& source, size_t row) { + if (source->IsNull(row)) { + builder.UnsafeAppendNull(); + return Status::OK(); + } + using CType = typename TypeTraits::CType; + builder.UnsafeAppend(source->template GetValues(1)[row]); + return Status::OK(); +} + +template ::BuilderType> +enable_if_base_binary BuilderAppend( + Builder& builder, const std::shared_ptr& source, size_t row) { + if (source->IsNull(row)) { + return builder.AppendNull(); + } + using offset_type = typename Type::offset_type; + const uint8_t* data = source->buffers[2]->data(); + const offset_type* offsets = source->GetValues(1); + const offset_type offset0 = offsets[row]; + const offset_type offset1 = offsets[row + 1]; + return builder.Append(data + offset0, offset1 - offset0); +} + +>>>>>>> b34c999b6 (Create sorted merge node) #define NEW_TASK true #define POISON_PILL false @@ -94,8 +150,148 @@ class BackpressureController : public BackpressureControl { std::atomic& backpressure_counter_; }; +<<<<<<< HEAD /// InputState correponds to an input. Input record batches are queued up in InputState /// until processed and turned into output record batches. +======= +class UnmaterializedTable { + struct UnmaterializedRow { + arrow::RecordBatch* batch; + size_t rowNumber; + }; + + public: + struct UnmaterializedSlice { + std::shared_ptr batch; + int64_t start; + int64_t end; + + inline int64_t length() const { return end - start; } + }; + + explicit UnmaterializedTable(const std::shared_ptr& schema_) + : schema(schema_), ptr2Ref{} {} + + void addSlice(const UnmaterializedSlice& slice) { + addRecordBatchRef(slice.batch); + auto t = std::make_tuple(slice.batch.get(), slice.start, slice.end); + slices.push_back(std::move(t)); + numRows += slice.end - slice.start; + } + + void addRow(const std::shared_ptr& batch, size_t rowNumber) { + addRecordBatchRef(batch); + auto t = std::make_tuple(batch.get(), rowNumber, rowNumber + 1); + slices.emplace_back(std::move(t)); + ++numRows; + } + + arrow::Result> materialize() { + // Don't build empty batches + if (empty()) { + return nullptr; + } + DCHECK_LE(getNumRows(), (uint64_t)std::numeric_limits::max()); + std::vector> arrays(schema->num_fields()); + + // https://github.com/apache/arrow/blob/2455bc07e09cd5341d1fabdb293afbd07682f0b2/cpp/src/arrow/acero/asof_join_node.cc#L1089C1-L1096C4 +#define SORTED_MERGE_MATERIALIZE_CASE(id) \ + case arrow::Type::id: { \ + using T = typename arrow::TypeIdTraits::Type; \ + ARROW_ASSIGN_OR_RAISE(arrays.at(iCol), materializeColumn(fieldType, iCol)); \ + break; \ + } + + // Build the arrays column-by-column from the rows + for (int iCol = 0; iCol < schema->num_fields(); ++iCol) { + const std::shared_ptr& field = schema->field(iCol); + const auto& fieldType = field->type(); + + switch (fieldType->id()) { + SORTED_MERGE_MATERIALIZE_CASE(BOOL) + SORTED_MERGE_MATERIALIZE_CASE(INT8) + SORTED_MERGE_MATERIALIZE_CASE(INT16) + SORTED_MERGE_MATERIALIZE_CASE(INT32) + SORTED_MERGE_MATERIALIZE_CASE(INT64) + SORTED_MERGE_MATERIALIZE_CASE(UINT8) + SORTED_MERGE_MATERIALIZE_CASE(UINT16) + SORTED_MERGE_MATERIALIZE_CASE(UINT32) + SORTED_MERGE_MATERIALIZE_CASE(UINT64) + SORTED_MERGE_MATERIALIZE_CASE(FLOAT) + SORTED_MERGE_MATERIALIZE_CASE(DOUBLE) + SORTED_MERGE_MATERIALIZE_CASE(DATE32) + SORTED_MERGE_MATERIALIZE_CASE(DATE64) + SORTED_MERGE_MATERIALIZE_CASE(TIME32) + SORTED_MERGE_MATERIALIZE_CASE(TIME64) + SORTED_MERGE_MATERIALIZE_CASE(TIMESTAMP) + SORTED_MERGE_MATERIALIZE_CASE(STRING) + SORTED_MERGE_MATERIALIZE_CASE(LARGE_STRING) + SORTED_MERGE_MATERIALIZE_CASE(BINARY) + SORTED_MERGE_MATERIALIZE_CASE(LARGE_BINARY) + default: + return arrow::Status::Invalid("Unsupported data type ", + field->type()->ToString(), " for field ", + field->name()); + } + } + +#undef SORTED_MERGE_MATERIALIZE_CASE + + std::shared_ptr r = + arrow::RecordBatch::Make(schema, (int64_t)numRows, arrays); + return r; + } + + size_t getNumRows() const { return numRows; } + size_t empty() const { return numRows == 0; } + + private: + std::shared_ptr schema; + /// A map from address of a record batch to the record batch. Used to + /// maintain the lifetime of the record batch in case it goes out of scope + /// by the main exec node thread + std::unordered_map> ptr2Ref; + + std::vector> slices = {}; + size_t numRows = 0; + + template ::BuilderType> + arrow::Result> materializeColumn( + const std::shared_ptr& type, int iCol, + arrow::MemoryPool* pool = arrow::default_memory_pool()) { + ARROW_ASSIGN_OR_RAISE(auto builderPtr, arrow::MakeBuilder(type, pool)); + Builder& builder = *arrow::internal::checked_cast(builderPtr.get()); + ARROW_RETURN_NOT_OK(builder.Reserve(numRows)); + + for (const auto& [batch, start, end] : slices) { + if (batch) { + for (int64_t rowNum = start; rowNum < end; ++rowNum) { + arrow::Status st = + BuilderAppend(builder, batch->column_data(iCol), rowNum); + ARROW_RETURN_NOT_OK(st); + } + } else { + for (int64_t rowNum = start; rowNum < end; ++rowNum) { + ARROW_RETURN_NOT_OK(builder.AppendNull()); + } + } + } + std::shared_ptr result; + ARROW_RETURN_NOT_OK(builder.Finish(&result)); + return Result{std::move(result)}; + } + + void addRecordBatchRef(const std::shared_ptr& ref) { + if (!ptr2Ref.count((uintptr_t)ref.get())) { + ptr2Ref[(uintptr_t)ref.get()] = ref; + } + } +}; + +/// InputState correponds to an input +/// Input record batches are queued up in InputState until processed and +/// turned into output record batches. +>>>>>>> b34c999b6 (Create sorted merge node) class InputState { public: InputState(size_t index, BackpressureHandler handler, @@ -107,16 +303,26 @@ class InputState { time_type_id_(schema_->fields()[time_col_index_]->type()->id()) {} template +<<<<<<< HEAD static arrow::Result Make(size_t index, arrow::acero::ExecNode* node, +======= + static arrow::Result Make(size_t index, arrow::acero::ExecNode* input, +>>>>>>> b34c999b6 (Create sorted merge node) arrow::acero::ExecNode* output, std::atomic& backpressure_counter, const std::shared_ptr& schema, const col_index_t time_col_index) { constexpr size_t low_threshold = 4, high_threshold = 8; std::unique_ptr backpressure_control = +<<<<<<< HEAD std::make_unique(node, output, backpressure_counter); ARROW_ASSIGN_OR_RAISE(auto handler, BackpressureHandler::Make(low_threshold, high_threshold, +======= + std::make_unique(input, output, backpressure_counter); + ARROW_ASSIGN_OR_RAISE(auto handler, + BackpressureHandler::Make(input, low_threshold, high_threshold, +>>>>>>> b34c999b6 (Create sorted merge node) std::move(backpressure_control))); return PtrType(new InputState(index, std::move(handler), schema, time_col_index)); } @@ -155,16 +361,28 @@ class InputState { } inline time_unit_t GetLatestTime() const { +<<<<<<< HEAD return GetTime(GetLatestBatch().get(), time_type_id_, time_col_index_, latest_ref_row_); +======= + return GetTime(GetLatestBatch().get(), latest_ref_row_); + } + + inline time_unit_t GetTime(const arrow::RecordBatch* batch, row_index_t row) const { + return get_time(batch, time_type_id_, time_col_index_, row); +>>>>>>> b34c999b6 (Create sorted merge node) } #undef LATEST_VAL_CASE bool Finished() const { return batches_processed_ == total_batches_; } +<<<<<<< HEAD arrow::Result>> Advance() { +======= + arrow::Result Advance() { +>>>>>>> b34c999b6 (Create sorted merge node) // Advance the row until a new time is encountered or the record batch // ends. This will return a range of {-1, -1} and a nullptr if there is // no input @@ -173,18 +391,30 @@ class InputState { (latest_ref_row_ > 0 /*short circuit the lock on the queue*/) || !queue_.Empty(); if (!active) { +<<<<<<< HEAD return std::make_pair(UnmaterializedSlice(), nullptr); +======= + return UnmaterializedTable::UnmaterializedSlice{nullptr, -1, -1}; +>>>>>>> b34c999b6 (Create sorted merge node) } row_index_t start = latest_ref_row_; row_index_t end = latest_ref_row_; time_unit_t startTime = GetLatestTime(); std::shared_ptr batch = queue_.UnsyncFront(); +<<<<<<< HEAD auto rows_in_batch = (row_index_t)batch->num_rows(); while (GetLatestTime() == startTime) { end = ++latest_ref_row_; if (latest_ref_row_ >= rows_in_batch) { +======= + auto rowsInBatch = (row_index_t)batch->num_rows(); + + while (GetLatestTime() == startTime) { + end = ++latest_ref_row_; + if (latest_ref_row_ >= rowsInBatch) { +>>>>>>> b34c999b6 (Create sorted merge node) // hit the end of the batch, need to get the next batch if // possible. ++batches_processed_; @@ -197,11 +427,43 @@ class InputState { break; } } +<<<<<<< HEAD UnmaterializedSlice slice; slice.num_components = 1; slice.components[0] = CompositeEntry{batch.get(), start, end}; return std::make_pair(slice, batch); +======= + return UnmaterializedTable::UnmaterializedSlice{batch, // + static_cast(start), // + static_cast(end)}; + } + + arrow::Result AdvanceOnce() { + // Try advancing to the next row and update latest_ref_row_ + // Returns true if able to advance, false if not. + bool have_active_batch = + (latest_ref_row_ > 0 /*short circuit the lock on the queue*/) || !queue_.Empty(); + + if (have_active_batch) { + time_unit_t next_time = GetLatestTime(); + latest_time_ = next_time; + auto rowsInBatch = (row_index_t)queue_.UnsyncFront()->num_rows(); + // If we have an active batch + if (++latest_ref_row_ >= rowsInBatch) { + // hit the end of the batch, need to get the next batch if + // possible. + ++batches_processed_; + latest_ref_row_ = 0; + have_active_batch &= !queue_.TryPop(); + if (have_active_batch) { + // empty batches disallowed + DCHECK_GT(queue_.UnsyncFront()->num_rows(), 0); + } + } + } + return have_active_batch; +>>>>>>> b34c999b6 (Create sorted merge node) } arrow::Status Push(const std::shared_ptr& rb) { @@ -268,21 +530,38 @@ class SortedMergeNode : public ExecNode { std::vector inputs, std::shared_ptr output_schema, arrow::Ordering new_ordering) +<<<<<<< HEAD : ExecNode(plan, inputs, GetInputLabels(inputs), std::move(output_schema)), ordering_(std::move(new_ordering)), input_counter(inputs_.size()), output_counter(inputs_.size()), process_thread() { +======= + : ExecNode(plan, inputs, getInputLabels(inputs), std::move(output_schema)), + ordering_(std::move(new_ordering)), + inputCounter(inputs_.size()), + outputCounter(inputs_.size()), + processThread() { +>>>>>>> b34c999b6 (Create sorted merge node) SetLabel("sorted_merge"); } ~SortedMergeNode() override { +<<<<<<< HEAD process_queue.Push( POISON_PILL); // poison pill // We might create a temporary (such as to inspect the output // schema), in which case there isn't anything to join if (process_thread.joinable()) { process_thread.join(); +======= + processQueue.Push( + POISON_PILL); // poison pill + // We might create a temporary (such as to inspect the output + // schema), in which case there isn't anything to join + if (processThread.joinable()) { + processThread.join(); +>>>>>>> b34c999b6 (Create sorted merge node) } } @@ -341,7 +620,11 @@ class SortedMergeNode : public ExecNode { ARROW_ASSIGN_OR_RAISE(auto input_state, InputState::Make>( +<<<<<<< HEAD i, input, this, backpressure_counter, schema, +======= + i, input, this, backpressureCounter, schema, +>>>>>>> b34c999b6 (Create sorted merge node) schema->GetFieldIndex(*ref.name()))); state.push_back(std::move(input_state)); } @@ -357,9 +640,15 @@ class SortedMergeNode : public ExecNode { // Push into the queue. Note that we don't need to lock since // InputState's ConcurrentQueue manages locking +<<<<<<< HEAD input_counter[index] += rb->num_rows(); ARROW_RETURN_NOT_OK(state[index]->Push(rb)); process_queue.Push(NEW_TASK); +======= + inputCounter[index] += rb->num_rows(); + ARROW_RETURN_NOT_OK(state[index]->Push(rb)); + processQueue.Push(NEW_TASK); +>>>>>>> b34c999b6 (Create sorted merge node) return Status::OK(); } @@ -372,11 +661,16 @@ class SortedMergeNode : public ExecNode { state.at(k)->set_total_batches(total_batches); } // Trigger a final process call for stragglers +<<<<<<< HEAD process_queue.Push(NEW_TASK); +======= + processQueue.Push(NEW_TASK); +>>>>>>> b34c999b6 (Create sorted merge node) return Status::OK(); } arrow::Status StartProducing() override { +<<<<<<< HEAD ARROW_ASSIGN_OR_RAISE(process_task, plan_->query_context()->BeginExternalTask( "SortedMergeNode::ProcessThread")); if (!process_task.is_valid()) { @@ -384,12 +678,26 @@ class SortedMergeNode : public ExecNode { return Status::OK(); } process_thread = std::thread(&SortedMergeNode::StartPoller, this); +======= + ARROW_ASSIGN_OR_RAISE(processTask, plan_->query_context()->BeginExternalTask( + "SortedMergeNode::ProcessThread")); + if (!processTask.is_valid()) { + // Plan has already aborted. Do not start process thread + return Status::OK(); + } + processThread = std::thread(&SortedMergeNode::startPoller, this); +>>>>>>> b34c999b6 (Create sorted merge node) return Status::OK(); } arrow::Status StopProducingImpl() override { +<<<<<<< HEAD process_queue.Clear(); process_queue.Push(POISON_PILL); +======= + processQueue.Clear(); + processQueue.Push(POISON_PILL); +>>>>>>> b34c999b6 (Create sorted merge node) return Status::OK(); } @@ -406,27 +714,49 @@ class SortedMergeNode : public ExecNode { private: void EndFromProcessThread(arrow::Status st = arrow::Status::OK()) { +<<<<<<< HEAD ARROW_CHECK(!cleanup_started); for (size_t i = 0; i < input_counter.size(); ++i) { ARROW_CHECK(input_counter[i] == output_counter[i]) << input_counter[i] << " != " << output_counter[i]; +======= + ARROW_CHECK(!cleanupStarted); + for (size_t i = 0; i < inputCounter.size(); ++i) { + ARROW_CHECK(inputCounter[i] == outputCounter[i]) + << inputCounter[i] << " != " << outputCounter[i]; +>>>>>>> b34c999b6 (Create sorted merge node) } ARROW_UNUSED( plan_->query_context()->executor()->Spawn([this, st = std::move(st)]() mutable { +<<<<<<< HEAD Defer cleanup([this, &st]() { process_task.MarkFinished(st); }); if (st.ok()) { st = output_->InputFinished(this, batches_produced); +======= + Defer cleanup([this, &st]() { processTask.MarkFinished(st); }); + if (st.ok()) { + st = output_->InputFinished(this, batchesProduced); +>>>>>>> b34c999b6 (Create sorted merge node) } })); } +<<<<<<< HEAD bool CheckEnded() { bool all_finished = true; for (const auto& s : state) { all_finished &= s->Finished(); } if (all_finished) { +======= + bool checkEnded() { + bool allFinished = true; + for (const auto& s : state) { + allFinished &= s->Finished(); + } + if (allFinished) { +>>>>>>> b34c999b6 (Create sorted merge node) EndFromProcessThread(); return false; } @@ -449,6 +779,7 @@ class SortedMergeNode : public ExecNode { [](const std::shared_ptr& s) { return s->Finished(); }), heap.end()); // Currently we only support one sort key +<<<<<<< HEAD const auto sort_col = *ordering_.sort_keys().at(0).target.name(); const auto comp = InputStateComparator(); std::make_heap(heap.begin(), heap.end(), comp); @@ -486,24 +817,68 @@ class SortedMergeNode : public ExecNode { } if (next_item->Finished() || next_item->Empty()) { +======= + const auto sortCol = *ordering_.sort_keys().at(0).target.name(); + const auto comp = InputStateComparator(); + std::make_heap(heap.begin(), heap.end(), comp); + + UnmaterializedTable output(output_schema()); + + // Generate rows until we run out of data or we exceed the target output + // size + while (!heap.empty() && output.getNumRows() < kTargetOutputBatchSize) { + std::pop_heap(heap.begin(), heap.end(), comp); + + auto& nextItem = heap.back(); + time_unit_t latestTime = std::numeric_limits::min(); + time_unit_t newTime = nextItem->GetLatestTime(); + ARROW_CHECK(newTime >= latestTime) << "Input state " << nextItem->index() + << " has out of order data. newTime=" << newTime + << " latestTime=" << latestTime; + + latestTime = newTime; + ARROW_ASSIGN_OR_RAISE(UnmaterializedTable::UnmaterializedSlice slice, + nextItem->Advance()); + + if (slice.length() > 0) { + outputCounter[nextItem->index()] += slice.length(); + output.addSlice(slice); + } + + if (nextItem->Finished() || nextItem->Empty()) { +>>>>>>> b34c999b6 (Create sorted merge node) heap.pop_back(); } std::make_heap(heap.begin(), heap.end(), comp); } // Emit the batch +<<<<<<< HEAD if (output.Size() == 0) { return nullptr; } auto result = output.Materialize(); +======= + if (output.getNumRows() == 0) { + return nullptr; + } + + auto result = output.materialize(); +>>>>>>> b34c999b6 (Create sorted merge node) return result; } /// Gets a batch. Returns true if there is more data to process, false if we /// are done or an error occurred +<<<<<<< HEAD bool PollOnce() { std::lock_guard guard(gate); if (!CheckEnded()) { +======= + bool pollOnce() { + std::lock_guard guard(gate); + if (!checkEnded()) { +>>>>>>> b34c999b6 (Create sorted merge node) return false; } @@ -517,7 +892,11 @@ class SortedMergeNode : public ExecNode { break; } ExecBatch out_b(*out_rb); +<<<<<<< HEAD out_b.index = batches_produced++; +======= + out_b.index = batchesProduced++; +>>>>>>> b34c999b6 (Create sorted merge node) Status st = output_->InputReceived(this, std::move(out_b)); if (!st.ok()) { ARROW_LOG(FATAL) << "Error in output_::InputReceived: " << st.ToString(); @@ -536,7 +915,11 @@ class SortedMergeNode : public ExecNode { // It may happen here in cases where InputFinished was called before // we were finished producing results (so we didn't know the output // size at that time) +<<<<<<< HEAD if (!CheckEnded()) { +======= + if (!checkEnded()) { +>>>>>>> b34c999b6 (Create sorted merge node) return false; } @@ -545,6 +928,7 @@ class SortedMergeNode : public ExecNode { return true; } +<<<<<<< HEAD void EmitBatches() { while (true) { // Implementation note: If the queue is empty, we will block here @@ -553,18 +937,33 @@ class SortedMergeNode : public ExecNode { } // Either we're out of data or something went wrong if (!PollOnce()) { +======= + void emitBatches() { + while (true) { + // Implementation note: If the queue is empty, we will block here + if (processQueue.Pop() == POISON_PILL) { + EndFromProcessThread(); + } + // Either we're out of data or something went wrong + if (!pollOnce()) { +>>>>>>> b34c999b6 (Create sorted merge node) return; } } } /// The entry point for processThread +<<<<<<< HEAD static void StartPoller(SortedMergeNode* node) { node->EmitBatches(); } +======= + static void startPoller(SortedMergeNode* node) { node->emitBatches(); } +>>>>>>> b34c999b6 (Create sorted merge node) arrow::Ordering ordering_; // Each input state corresponds to an input (e.g. a parquet data file) std::vector> state; +<<<<<<< HEAD std::vector input_counter; std::vector output_counter; std::mutex gate; @@ -582,6 +981,25 @@ class SortedMergeNode : public ExecNode { // input states and emit batches std::thread process_thread; arrow::Future<> process_task; +======= + std::vector inputCounter; + std::vector outputCounter; + std::mutex gate; + + std::atomic cleanupStarted{false}; + + // Backpressure counter common to all input states + std::atomic backpressureCounter; + + std::atomic batchesProduced{0}; + + // Queue to trigger processing of a given input. False acts as a poison pill + ConcurrentQueue processQueue; + // Once StartProducing is called, we initialize this thread to poll the + // input states and emit batches + std::thread processThread; + arrow::Future<> processTask; +>>>>>>> b34c999b6 (Create sorted merge node) // Map arg index --> completion counter std::vector counter_; @@ -590,10 +1008,21 @@ class SortedMergeNode : public ExecNode { std::mutex mutex_; std::atomic total_batches_{0}; }; +<<<<<<< HEAD +======= +} // namespace + +namespace internal { +void RegisterSortedMergeNode(ExecFactoryRegistry* registry) { + DCHECK_OK(registry->AddFactory("sorted_merge", SortedMergeNode::Make)); +} +} // namespace internal +>>>>>>> b34c999b6 (Create sorted merge node) #undef NEW_TASK #undef POISON_PILL +<<<<<<< HEAD } // namespace sorted_merge namespace internal { @@ -602,4 +1031,6 @@ void RegisterSortedMergeNode(ExecFactoryRegistry* registry) { } } // namespace internal +======= +>>>>>>> b34c999b6 (Create sorted merge node) } // namespace arrow::acero diff --git a/cpp/src/arrow/acero/time_series_util.cc b/cpp/src/arrow/acero/time_series_util.cc index 5fb3445a61307..486906a13c5b8 100644 --- a/cpp/src/arrow/acero/time_series_util.cc +++ b/cpp/src/arrow/acero/time_series_util.cc @@ -15,6 +15,10 @@ // specific language governing permissions and limitations // under the License. +<<<<<<< HEAD +======= + +>>>>>>> b34c999b6 (Create sorted merge node) #include "arrow/array/data.h" #include "arrow/acero/time_series_util.h" @@ -24,12 +28,20 @@ namespace arrow::acero { // normalize the value to 64-bits while preserving ordering of values template ::value, bool>> +<<<<<<< HEAD inline uint64_t NormalizeTime(T t) { +======= +static inline uint64_t get_time_normalized(T t) { +>>>>>>> b34c999b6 (Create sorted merge node) uint64_t bias = std::is_signed::value ? (uint64_t)1 << (8 * sizeof(T) - 1) : 0; return t < 0 ? static_cast(t + bias) : static_cast(t); } +<<<<<<< HEAD uint64_t GetTime(const RecordBatch* batch, Type::type time_type, int col, uint64_t row) { +======= +uint64_t get_time(const RecordBatch* batch, Type::type time_type, int col, uint64_t row) { +>>>>>>> b34c999b6 (Create sorted merge node) #define LATEST_VAL_CASE(id, val) \ case Type::id: { \ using T = typename TypeIdTraits::Type; \ @@ -39,6 +51,7 @@ uint64_t GetTime(const RecordBatch* batch, Type::type time_type, int col, uint64 auto data = batch->column_data(col); switch (time_type) { +<<<<<<< HEAD LATEST_VAL_CASE(INT8, NormalizeTime) LATEST_VAL_CASE(INT16, NormalizeTime) LATEST_VAL_CASE(INT32, NormalizeTime) @@ -52,6 +65,21 @@ uint64_t GetTime(const RecordBatch* batch, Type::type time_type, int col, uint64 LATEST_VAL_CASE(TIME32, NormalizeTime) LATEST_VAL_CASE(TIME64, NormalizeTime) LATEST_VAL_CASE(TIMESTAMP, NormalizeTime) +======= + LATEST_VAL_CASE(INT8, get_time_normalized) + LATEST_VAL_CASE(INT16, get_time_normalized) + LATEST_VAL_CASE(INT32, get_time_normalized) + LATEST_VAL_CASE(INT64, get_time_normalized) + LATEST_VAL_CASE(UINT8, get_time_normalized) + LATEST_VAL_CASE(UINT16, get_time_normalized) + LATEST_VAL_CASE(UINT32, get_time_normalized) + LATEST_VAL_CASE(UINT64, get_time_normalized) + LATEST_VAL_CASE(DATE32, get_time_normalized) + LATEST_VAL_CASE(DATE64, get_time_normalized) + LATEST_VAL_CASE(TIME32, get_time_normalized) + LATEST_VAL_CASE(TIME64, get_time_normalized) + LATEST_VAL_CASE(TIMESTAMP, get_time_normalized) +>>>>>>> b34c999b6 (Create sorted merge node) default: DCHECK(false); return 0; // cannot happen diff --git a/cpp/src/arrow/acero/time_series_util.h b/cpp/src/arrow/acero/time_series_util.h index c74cb6b712d7b..82796fb716b5b 100644 --- a/cpp/src/arrow/acero/time_series_util.h +++ b/cpp/src/arrow/acero/time_series_util.h @@ -24,8 +24,14 @@ namespace arrow::acero { // normalize the value to 64-bits while preserving ordering of values template ::value, bool> = true> +<<<<<<< HEAD inline uint64_t NormalizeTime(T t); uint64_t GetTime(const RecordBatch* batch, Type::type time_type, int col, uint64_t row); +======= +static inline uint64_t get_time_normalized(T t); + +uint64_t get_time(const RecordBatch* batch, Type::type time_type, int col, uint64_t row); +>>>>>>> b34c999b6 (Create sorted merge node) } // namespace arrow::acero