diff --git a/cpp/src/arrow/dataset/file_parquet.cc b/cpp/src/arrow/dataset/file_parquet.cc index 1f8b6cc4882cf..39df6bfdfc6bf 100644 --- a/cpp/src/arrow/dataset/file_parquet.cc +++ b/cpp/src/arrow/dataset/file_parquet.cc @@ -36,6 +36,7 @@ #include "arrow/util/iterator.h" #include "arrow/util/logging.h" #include "arrow/util/range.h" +#include "arrow/util/thread_pool.h" #include "arrow/util/tracing_internal.h" #include "parquet/arrow/reader.h" #include "parquet/arrow/schema.h" @@ -633,10 +634,15 @@ Result ParquetFileFormat::ScanBatchesAsync( kParquetTypeName, options.get(), default_fragment_scan_options)); int batch_readahead = options->batch_readahead; int64_t rows_to_readahead = batch_readahead * options->batch_size; - ARROW_ASSIGN_OR_RAISE(auto generator, - reader->GetRecordBatchGenerator( - reader, row_groups, column_projection, - ::arrow::internal::GetCpuThreadPool(), rows_to_readahead)); + // Use the executor set on options->exec_context, if any, rather than always using + // the default CPU thread pool. + // TODO: Decide whether it should come from options->fragment_scan_options instead. + auto cpu_executor = options->exec_context.executor() + ? 
options->exec_context.executor() + : ::arrow::internal::GetCpuThreadPool(); + ARROW_ASSIGN_OR_RAISE(auto generator, reader->GetRecordBatchGenerator( + reader, row_groups, column_projection, + cpu_executor, rows_to_readahead)); RecordBatchGenerator sliced = SlicingGenerator(std::move(generator), options->batch_size); if (batch_readahead == 0) { diff --git a/cpp/src/arrow/dataset/file_parquet_test.cc b/cpp/src/arrow/dataset/file_parquet_test.cc index bf626826d4d1b..94be9cb7969d0 100644 --- a/cpp/src/arrow/dataset/file_parquet_test.cc +++ b/cpp/src/arrow/dataset/file_parquet_test.cc @@ -22,6 +22,7 @@ #include #include +#include "arrow/acero/exec_plan.h" #include "arrow/compute/api_scalar.h" #include "arrow/dataset/dataset_internal.h" #include "arrow/dataset/parquet_encryption_config.h" @@ -910,5 +911,53 @@ TEST_F(TestParquetFileFormat, MultithreadedScanRegression) { } } +TEST_F(TestParquetFileFormat, MultithreadedComputeRegression) { + // GH-43694: Same scenario as MultithreadedScanRegression, but each scanner is given + // its own thread pool through options->exec_context. + + auto reader = MakeGeneratedRecordBatch(schema({field("utf8", utf8())}), 10000, 100); + ASSERT_OK_AND_ASSIGN(auto buffer, ParquetFormatHelper::Write(reader.get())); + + std::vector> completes; + std::vector> pools; + + for (int idx = 0; idx < 2; ++idx) { + auto buffer_reader = std::make_shared(buffer); + auto source = std::make_shared(buffer_reader, buffer->size()); + auto fragment = MakeFragment(*source); + std::shared_ptr scanner; + + { + auto options = std::make_shared(); + ASSERT_OK_AND_ASSIGN(auto thread_pool, arrow::internal::ThreadPool::Make(1)); + pools.emplace_back(thread_pool); + options->exec_context = + ::arrow::ExecContext(::arrow::default_memory_pool(), pools.back().get()); + auto fragment_scan_options = std::make_shared(); + fragment_scan_options->arrow_reader_properties->set_pre_buffer(true); + + options->fragment_scan_options = fragment_scan_options; + ScannerBuilder builder(ArithmeticDatasetFixture::schema(), 
fragment, options); + + ASSERT_OK(builder.UseThreads(true)); + ASSERT_OK(builder.BatchSize(10000)); + ASSERT_OK_AND_ASSIGN(scanner, builder.Finish()); + } + + ASSERT_OK_AND_ASSIGN(auto batch, scanner->Head(10000)); + [[maybe_unused]] auto fut = scanner->ScanBatchesUnorderedAsync(); + // Issue extra ReadAsync calls on the same reader so more futures are outstanding, + // which makes the state machine more complex. + for (int yy = 0; yy < 16; yy++) { + completes.emplace_back(buffer_reader->ReadAsync(::arrow::io::IOContext(), 0, 1001)); + } + scanner = nullptr; + } + + for (auto& f : completes) { + f.Wait(); + } +} + } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc index a856a792a264f..5c10dfc6ac5fb 100644 --- a/cpp/src/arrow/dataset/scanner.cc +++ b/cpp/src/arrow/dataset/scanner.cc @@ -355,8 +355,10 @@ class OneShotFragment : public Fragment { ARROW_ASSIGN_OR_RAISE( auto background_gen, MakeBackgroundGenerator(std::move(batch_it_), options->io_context.executor())); - return MakeTransferredGenerator(std::move(background_gen), - ::arrow::internal::GetCpuThreadPool()); + auto cpu_executor = options->exec_context.executor() + ? 
options->exec_context.executor() + : ::arrow::internal::GetCpuThreadPool(); + return MakeTransferredGenerator(std::move(background_gen), cpu_executor); } std::string type_name() const override { return "one-shot"; } @@ -382,7 +384,7 @@ Result AsyncScanner::ScanBatches() { [this](::arrow::internal::Executor* executor) { return ScanBatchesAsync(executor); }, - scan_options_->use_threads); + scan_options_->use_threads, this->async_cpu_executor()); } Result AsyncScanner::ScanBatchesUnordered() { @@ -390,7 +392,7 @@ Result AsyncScanner::ScanBatchesUnordered() { [this](::arrow::internal::Executor* executor) { return ScanBatchesUnorderedAsync(executor); }, - scan_options_->use_threads); + scan_options_->use_threads, this->async_cpu_executor()); } Result> AsyncScanner::ToTable() { @@ -400,7 +402,7 @@ Result> AsyncScanner::ToTable() { } Result AsyncScanner::ScanBatchesUnorderedAsync() { - return ScanBatchesUnorderedAsync(::arrow::internal::GetCpuThreadPool(), + return ScanBatchesUnorderedAsync(this->async_cpu_executor(), /*sequence_fragments=*/false); } @@ -601,7 +603,7 @@ Result> AsyncScanner::Head(int64_t num_rows) { } Result AsyncScanner::ScanBatchesAsync() { - return ScanBatchesAsync(::arrow::internal::GetCpuThreadPool()); + return ScanBatchesAsync(this->async_cpu_executor()); } Result AsyncScanner::ScanBatchesAsync( @@ -778,7 +780,7 @@ Future AsyncScanner::CountRowsAsync(Executor* executor) { } Future AsyncScanner::CountRowsAsync() { - return CountRowsAsync(::arrow::internal::GetCpuThreadPool()); + return CountRowsAsync(this->async_cpu_executor()); } Result AsyncScanner::CountRows() { diff --git a/cpp/src/arrow/dataset/scanner.h b/cpp/src/arrow/dataset/scanner.h index d2de267897180..1c605c1bf21f6 100644 --- a/cpp/src/arrow/dataset/scanner.h +++ b/cpp/src/arrow/dataset/scanner.h @@ -107,6 +107,11 @@ struct ARROW_DS_EXPORT ScanOptions { /// Note: The IOContext executor will be ignored if use_threads is set to false io::IOContext io_context; + /// ExecContext for any CPU 
tasks + /// + /// Note: The ExecContext executor will be ignored if use_threads is set to false + compute::ExecContext exec_context; + /// If true the scanner will scan in parallel /// /// Note: If true, this will use threads from both the cpu_executor and the @@ -442,6 +447,11 @@ class ARROW_DS_EXPORT Scanner { TaggedRecordBatchIterator scan); const std::shared_ptr scan_options_; + + ::arrow::internal::Executor* async_cpu_executor() const { + return scan_options_->exec_context.executor() ? scan_options_->exec_context.executor() + : ::arrow::internal::GetCpuThreadPool(); + } }; /// \brief ScannerBuilder is a factory class to construct a Scanner. It is used diff --git a/cpp/src/arrow/util/thread_pool.h b/cpp/src/arrow/util/thread_pool.h index 44b1e227b0e5f..0d3babd38ed20 100644 --- a/cpp/src/arrow/util/thread_pool.h +++ b/cpp/src/arrow/util/thread_pool.h @@ -591,6 +591,21 @@ typename Fut::SyncType RunSynchronously(FnOnce get_future, } } +template +Iterator IterateSynchronously( + FnOnce()>>(Executor*)> get_gen, bool use_threads, + Executor* executor) { + if (use_threads) { + auto maybe_gen = std::move(get_gen)(executor); + if (!maybe_gen.ok()) { + return MakeErrorIterator(maybe_gen.status()); + } + return MakeGeneratorIterator(*maybe_gen); + } else { + return SerialExecutor::IterateGenerator(std::move(get_gen)); + } +} + /// \brief Potentially iterate an async generator serially (if use_threads is false) /// \see IterateGenerator /// @@ -605,15 +620,7 @@ typename Fut::SyncType RunSynchronously(FnOnce get_future, template Iterator IterateSynchronously( FnOnce()>>(Executor*)> get_gen, bool use_threads) { - if (use_threads) { - auto maybe_gen = std::move(get_gen)(GetCpuThreadPool()); - if (!maybe_gen.ok()) { - return MakeErrorIterator(maybe_gen.status()); - } - return MakeGeneratorIterator(*maybe_gen); - } else { - return SerialExecutor::IterateGenerator(std::move(get_gen)); - } + return IterateSynchronously(std::move(get_gen), use_threads, GetCpuThreadPool()); } } // 
namespace internal