diff --git a/velox/experimental/wave/common/CMakeLists.txt b/velox/experimental/wave/common/CMakeLists.txt index db1f54f25568c..b707cb9efe271 100644 --- a/velox/experimental/wave/common/CMakeLists.txt +++ b/velox/experimental/wave/common/CMakeLists.txt @@ -16,8 +16,10 @@ velox_add_library( velox_wave_common GpuArena.cpp Buffer.cpp + Compile.cu Cuda.cu Exception.cpp + KernelCache.cpp Type.cpp ResultStaging.cpp) diff --git a/velox/experimental/wave/common/Compile.cu b/velox/experimental/wave/common/Compile.cu new file mode 100644 index 0000000000000..b209caef066cb --- /dev/null +++ b/velox/experimental/wave/common/Compile.cu @@ -0,0 +1,184 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "velox/experimental/wave/common/Cuda.h" +#include "velox/experimental/wave/common/CudaUtil.cuh" +#include "velox/experimental/wave/common/Exception.h" + +DEFINE_string( + wavegen_architecture, + "compute_70", + "--gpu-architecture flag for generated code"); + +namespace facebook::velox::wave { + +namespace { +void nvrtcCheck(nvrtcResult result) { + if (result != NVRTC_SUCCESS) { + waveError(nvrtcGetErrorString(result)); + } +} + +class CompiledModuleImpl : public CompiledModule { + public: + CompiledModuleImpl(CUmodule module, std::vector kernels) + : module_(module), kernels_(std::move(kernels)) {} + + ~CompiledModuleImpl() { + auto result = cuModuleUnload(module_); + if (result != CUDA_SUCCESS) { + LOG(ERROR) << "Error in unloading module " << result; + } + } + + void launch( + int32_t kernelIdx, + int32_t numBlocks, + int32_t numThreads, + int32_t shared, + Stream* stream, + void** args) override; + + KernelInfo info(int32_t kernelIdx) override; + + private: + CUmodule module_; + std::vector kernels_; +}; +} // namespace + +std::shared_ptr CompiledModule::create(const KernelSpec& spec) { + nvrtcProgram prog; + nvrtcCreateProgram( + &prog, + spec.code.c_str(), // buffer + spec.filePath.c_str(), // name + spec.numHeaders, // numHeaders + spec.headers, // headers + spec.headerNames); // includeNames + for (auto& name : spec.entryPoints) { + nvrtcCheck(nvrtcAddNameExpression(prog, name.c_str())); + } + auto architecture = + fmt::format("--gpu-architecture={}", FLAGS_wavegen_architecture); + const char* opts[] = { + architecture.c_str() //, +#if 0 +#ifndef NDEBUG + "-G" +#else + "-O3" +#endif +#endif + }; + auto compileResult = nvrtcCompileProgram( + prog, // prog + sizeof(opts) / sizeof(char*), // numOptions + opts); // options + + size_t logSize; + + nvrtcGetProgramLogSize(prog, &logSize); + std::string log; + log.resize(logSize); + nvrtcGetProgramLog(prog, log.data()); + + if (compileResult != NVRTC_SUCCESS) { + nvrtcDestroyProgram(&prog); + waveError(std::string("Cuda compilation error: ") + log); + } + // Obtain PTX from the program. + size_t ptxSize; + nvrtcCheck(nvrtcGetPTXSize(prog, &ptxSize)); + std::string ptx; + ptx.resize(ptxSize); + nvrtcCheck(nvrtcGetPTX(prog, ptx.data())); + std::vector loweredNames; + for (auto& entry : spec.entryPoints) { + const char* temp; + nvrtcCheck(nvrtcGetLoweredName(prog, entry.c_str(), &temp)); + loweredNames.push_back(std::string(temp)); + } + + nvrtcDestroyProgram(&prog); + CUjit_option options[] = { + CU_JIT_INFO_LOG_BUFFER, + CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, + CU_JIT_ERROR_LOG_BUFFER, + CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES}; + char info[1024]; + char error[1024]; + uint32_t infoSize = sizeof(info); + uint32_t errorSize = sizeof(error); + void* values[] = {info, &infoSize, error, &errorSize}; + + CUmodule module; + auto loadResult = cuModuleLoadDataEx( + &module, ptx.data(), sizeof(values) / sizeof(void*), options, values); + if (loadResult != CUDA_SUCCESS) { + LOG(ERROR) << "Load error " << errorSize << " " << infoSize; + waveError(fmt::format("Error in load module: {} {}", info, error)); + } + std::vector funcs; + for (auto& name : loweredNames) { + funcs.emplace_back(); + CU_CHECK(cuModuleGetFunction(&funcs.back(), module, name.c_str())); + } + return std::make_shared(module, std::move(funcs)); +} + +void CompiledModuleImpl::launch( + int32_t kernelIdx, + int32_t numBlocks, + int32_t numThreads, + int32_t shared, + Stream* stream, + void** args) { + auto result = cuLaunchKernel( + kernels_[kernelIdx], + numBlocks, + 1, + 1, // grid dim + numThreads, + 1, + 1, // block dim + shared, + reinterpret_cast(stream->stream()->stream), + args, + 0); + CU_CHECK(result); +}; + +KernelInfo CompiledModuleImpl::info(int32_t kernelIdx) { + KernelInfo info; + auto f = kernels_[kernelIdx]; + cuFuncGetAttribute(&info.numRegs, CU_FUNC_ATTRIBUTE_NUM_REGS, f); + cuFuncGetAttribute( + &info.sharedMemory, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, f); + cuFuncGetAttribute( + &info.maxThreadsPerBlock, CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, f); + int32_t max; + cuOccupancyMaxActiveBlocksPerMultiprocessor(&max, f, 256, 0); + info.maxOccupancy0 = max; + cuOccupancyMaxActiveBlocksPerMultiprocessor(&max, f, 256, 256 * 32); + info.maxOccupancy32 = max; + return info; +} + +} // namespace facebook::velox::wave diff --git a/velox/experimental/wave/common/Cuda.cu b/velox/experimental/wave/common/Cuda.cu index b018c7283a74f..838f4a4b93593 100644 --- a/velox/experimental/wave/common/Cuda.cu +++ b/velox/experimental/wave/common/Cuda.cu @@ -21,10 +21,19 @@ #include "velox/experimental/wave/common/CudaUtil.cuh" #include "velox/experimental/wave/common/Exception.h" +#include #include namespace facebook::velox::wave { +void cuCheck(CUresult result, const char* file, int32_t line) { + if (result != CUDA_SUCCESS) { + const char* str; + cuGetErrorString(result, &str); + waveError(fmt::format("Cuda error: {}:{} {}", file, line, str)); + } +} + void cudaCheck(cudaError_t err, const char* file, int line) { if (err == cudaSuccess) { return; @@ -43,6 +52,91 @@ void cudaCheckFatal(cudaError_t err, const char* file, int line) { exit(1); } +namespace { +std::mutex ctxMutex; +bool driverInited = false; + +// A context for each device. Each is initialized on first use and made the +// primary context for the device. +std::vector contexts; +// Device structs to 1:1 to contexts. +std::vector> devices; + +Device* setDriverDevice(int32_t deviceId) { + if (!driverInited) { + std::lock_guard l(ctxMutex); + CU_CHECK(cuInit(0)); + int32_t cnt; + CU_CHECK(cuDeviceGetCount(&cnt)); + contexts.resize(cnt); + devices.resize(cnt); + if (cnt == 0) { + waveError("No Cuda devices found"); + } + } + if (deviceId >= contexts.size()) { + waveError(fmt::format("Bad device id {}", deviceId)); + } + if (contexts[deviceId] != nullptr) { + cuCtxSetCurrent(contexts[deviceId]); + return devices[deviceId].get(); + } + { + std::lock_guard l(ctxMutex); + CUdevice dev; + CU_CHECK(cuDeviceGet(&dev, deviceId)); + CU_CHECK(cuDevicePrimaryCtxRetain(&contexts[deviceId], dev)); + devices[deviceId] = std::make_unique(deviceId); + cudaDeviceProp prop; + CUDA_CHECK(cudaGetDeviceProperties(&prop, deviceId)); + auto& device = devices[deviceId]; + device->model = prop.name; + device->major = prop.major; + device->minor = prop.minor; + device->globalMB = prop.totalGlobalMem >> 20; + device->numSM = prop.multiProcessorCount; + device->sharedMemPerSM = prop.sharedMemPerMultiprocessor; + device->L2Size = prop.l2CacheSize; + device->persistingL2MaxSize = prop.persistingL2CacheMaxSize; + } + CU_CHECK(cuCtxSetCurrent(contexts[deviceId])); + return devices[deviceId].get(); +} + +} // namespace + +Device* currentDevice() { + CUcontext ctx; + CU_CHECK(cuCtxGetCurrent(&ctx)); + if (!ctx) { + return nullptr; + } + for (auto i = 0; i < contexts.size(); ++i) { + if (contexts[i] == ctx) { + return devices[i].get(); + } + } + waveError("Device context not found. Inconsistent state."); + return nullptr; +} + +Device* getDevice(int32_t deviceId) { + Device* save = nullptr; + if (driverInited) { + save = currentDevice(); + } + auto* dev = setDriverDevice(deviceId); + if (save) { + setDevice(save); + } + return dev; +} + +void setDevice(Device* device) { + setDriverDevice(device->deviceId); + CUDA_CHECK(cudaSetDevice(device->deviceId)); +} + namespace { class CudaManagedAllocator : public GpuAllocator { public: @@ -106,15 +200,7 @@ GpuAllocator* getHostAllocator(Device* /*device*/) { return allocator; } -// Always returns device 0. -Device* getDevice(int32_t /*preferredDevice*/) { - static Device device(0); - return &device; -} - -void setDevice(Device* device) { - CUDA_CHECK(cudaSetDevice(device->deviceId)); -} +Stream::Stream(std::unique_ptr impl) : stream_(std::move(impl)) {} Stream::Stream() { stream_ = std::make_unique(); @@ -122,7 +208,9 @@ Stream::Stream() { } Stream::~Stream() { - cudaStreamDestroy(stream_->stream); + if (stream_->stream) { + cudaStreamDestroy(stream_->stream); + } } void Stream::wait() { diff --git a/velox/experimental/wave/common/Cuda.h b/velox/experimental/wave/common/Cuda.h index 4735999a4dc4e..0da9bc2491208 100644 --- a/velox/experimental/wave/common/Cuda.h +++ b/velox/experimental/wave/common/Cuda.h @@ -28,20 +28,37 @@ namespace facebook::velox::wave { struct Device { explicit Device(int32_t id) : deviceId(id) {} + std::string toString() const; + int32_t deviceId; + + /// Excerpt from device properties. + std::string model; + int32_t major; + int32_t minor; + int32_t globalMB; + int32_t numSM; + int32_t sharedMemPerSM; + int32_t L2Size; + int32_t persistingL2MaxSize; }; -/// Checks that the machine has the right capability and returns a Device -/// struct. If 'preferredId' is given tries to return a Device on that device -/// id. -Device* getDevice(int32_t preferredId = -1); +/// Checks that the machine has the right capability and returns the device for +/// 'id' +Device* getDevice(int32_t id = 0); + /// Binds subsequent Cuda operations of the calling thread to 'device'. void setDevice(Device* device); +/// Returns the device bound to te calling thread or nullptr if none. +Device* currentDevice(); + struct StreamImpl; class Stream { public: + explicit Stream(std::unique_ptr impl); + Stream(); virtual ~Stream(); @@ -215,6 +232,61 @@ struct KernelInfo { std::string toString() const; }; +/// Specification of code to compile. +struct KernelSpec { + std::string code; + std::vector entryPoints; + std::string filePath; + int32_t numHeaders{0}; + const char** headers; + const char** headerNames{nullptr}; +}; + +/// Represents the result of compilation. Wrapped accessed through +/// CompiledKernel. +struct CompiledModule { + virtual ~CompiledModule() = default; + /// Compiles 'spec' and returns the result. + static std::shared_ptr create(const KernelSpec& spec); + + virtual void launch( + int32_t kernelIdx, + int32_t numBlocks, + int32_t numThreads, + int32_t shared, + Stream* stream, + void** args) = 0; + + /// Returns resource utilization for 'kernelIdx'th entry point. + virtual KernelInfo info(int32_t kernelIdx) = 0; +}; + +using KernelGenFunc = std::function; + +/// Represents a run-time compiled kernel. These are returned +/// immediately from a kernel cache. The compilation takes place +/// in the background. The member functions block until a possibly +/// pending compilation completes. +class CompiledKernel { + public: + virtual ~CompiledKernel() = default; + + /// Returns the compiled kernel for 'key'. Starts background compilation if + /// 'key's kernel is not compiled. Returns lightweight reference to state + /// owned by compiled kernel cache. + static std::unique_ptr getKernel( + const std::string& key, + KernelGenFunc func); + + virtual void launch( + int32_t idx, + int32_t numBlocks, + int32_t numThreads, + int32_t shared, + Stream* stream, + void** args) = 0; +}; + KernelInfo getRegisteredKernelInfo(const char* name); KernelInfo kernelInfo(const void* func); diff --git a/velox/experimental/wave/common/CudaUtil.cuh b/velox/experimental/wave/common/CudaUtil.cuh index b16e32a72f334..d754246e52fb9 100644 --- a/velox/experimental/wave/common/CudaUtil.cuh +++ b/velox/experimental/wave/common/CudaUtil.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include #include @@ -27,13 +28,20 @@ void cudaCheck(cudaError_t err, const char* file, int line); void cudaCheckFatal(cudaError_t err, const char* file, int line); +void cuCheck(CUresult result, const char* file, int32_t line); + #define CUDA_CHECK(e) ::facebook::velox::wave::cudaCheck(e, __FILE__, __LINE__) +#define CU_CHECK(e) ::facebook::velox::wave::cuCheck(e, __FILE__, __LINE__) + #ifndef CUDA_CHECK_FATAL #define CUDA_CHECK_FATAL(e) \ ::facebook::velox::wave::cudaCheckFatal(e, __FILE__, __LINE__) #endif +// Gets device and context for Driver API. Initializes on first use. +void getDeviceAndContext(CUdevice& device, CUcontext& context); + template __host__ __device__ constexpr inline T roundUp(T value, U factor) { return (value + (factor - 1)) / factor * factor; @@ -91,7 +99,7 @@ inline uint32_t __device__ deviceScale32(uint32_t n, uint32_t scale) { } struct StreamImpl { - cudaStream_t stream; + cudaStream_t stream{}; }; bool registerKernel(const char* name, const void* func); diff --git a/velox/experimental/wave/common/KernelCache.cpp b/velox/experimental/wave/common/KernelCache.cpp new file mode 100644 index 0000000000000..c81a2cbf88341 --- /dev/null +++ b/velox/experimental/wave/common/KernelCache.cpp @@ -0,0 +1,140 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/caching/CachedFactory.h" +#include "velox/experimental/wave/common/Cuda.h" + +#include +#include + +namespace facebook::velox::wave { + +using ModulePtr = std::shared_ptr; + +static folly::CPUThreadPoolExecutor* compilerExecutor() { + static std::unique_ptr pool = + std::make_unique(10); + return pool.get(); +} + +namespace { +class FutureCompiledModule : public CompiledModule { + public: + explicit FutureCompiledModule(folly::Future future) + : future_(std::move(future)) {} + + void launch( + int32_t kernelIdx, + int32_t numBlocks, + int32_t numThreads, + int32_t shared, + Stream* stream, + void** args) override { + ensureReady(); + module_->launch(kernelIdx, numBlocks, numThreads, shared, stream, args); + } + + KernelInfo info(int32_t kernelIdx) override { + ensureReady(); + return module_->info(kernelIdx); + } + + private: + void ensureReady() { + std::lock_guard l(mutex_); + // 'module_' is a shared_ptr, so read is not atomic. Read inside the mutex. + if (module_) { + return; + } + module_ = std::move(future_).get(); + } + + private: + std::mutex mutex_; + ModulePtr module_; + folly::Future future_; +}; + +using KernelPtr = CachedPtr; + +class AsyncCompiledKernel : public CompiledKernel { + public: + explicit AsyncCompiledKernel(KernelPtr ptr) : ptr_(std::move(ptr)) {} + + void launch( + int32_t kernelIdx, + int32_t numBlocks, + int32_t numThreads, + int32_t shared, + Stream* stream, + void** args) override { + (*ptr_)->launch(kernelIdx, numBlocks, numThreads, shared, stream, args); + } + + private: + KernelPtr ptr_; +}; + +class KernelGenerator { + public: + std::unique_ptr operator()( + const std::string, + const KernelGenFunc* gen) { + using ModulePromise = folly::Promise; + struct PromiseHolder { + ModulePromise promise; + }; + auto holder = std::make_shared(); + + auto future = holder->promise.getFuture(); + auto* device = currentDevice(); + compilerExecutor()->add([genCopy = *gen, holder, device]() { + setDevice(device); + auto spec = genCopy(); + auto module = CompiledModule::create(spec); + holder->promise.setValue(module); + }); + ModulePtr result = + std::make_shared(std::move(future)); + return std::make_unique(result); + } +}; + +using KernelCache = + CachedFactory; + +std::unique_ptr makeCache() { + auto generator = std::make_unique(); + return std::make_unique( + std::make_unique>(1000), + std::move(generator)); +} + +KernelCache& kernelCache() { + static std::unique_ptr cache = makeCache(); + return *cache; +} +} // namespace + +// static +std::unique_ptr CompiledKernel::getKernel( + const std::string& key, + KernelGenFunc gen) { + auto ptr = kernelCache().generate(key, &gen); + return std::make_unique(std::move(ptr)); +} + +} // namespace facebook::velox::wave diff --git a/velox/experimental/wave/common/tests/CMakeLists.txt b/velox/experimental/wave/common/tests/CMakeLists.txt index 8cb290f4baec9..4fffaf40ada08 100644 --- a/velox/experimental/wave/common/tests/CMakeLists.txt +++ b/velox/experimental/wave/common/tests/CMakeLists.txt @@ -17,6 +17,7 @@ add_executable( GpuArenaTest.cpp CudaTest.cpp CudaTest.cu + CompileTest.cu BlockTest.cpp BlockTest.cu HashTableTest.cpp @@ -34,4 +35,5 @@ target_link_libraries( GTest::gtest_main gflags::gflags glog::glog - Folly::folly) + Folly::folly + CUDA::nvrtc) diff --git a/velox/experimental/wave/common/tests/CompileTest.cu b/velox/experimental/wave/common/tests/CompileTest.cu new file mode 100644 index 0000000000000..16ec6e52f8337 --- /dev/null +++ b/velox/experimental/wave/common/tests/CompileTest.cu @@ -0,0 +1,139 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include "velox/experimental/wave/common/Buffer.h" +#include "velox/experimental/wave/common/CudaUtil.cuh" +#include "velox/experimental/wave/common/Exception.h" +#include "velox/experimental/wave/common/GpuArena.h" +#include "velox/experimental/wave/common/tests/BlockTest.h" + +#include + +namespace facebook::velox::wave { + +void testCuCheck(CUresult result) { + if (result != CUDA_SUCCESS) { + const char* str; + cuGetErrorString(result, &str); + waveError(std::string("Cuda error: ") + str); + } +} + +class CompileTest : public testing::Test { + protected: + void SetUp() override { + device_ = getDevice(); + setDevice(device_); + allocator_ = getAllocator(device_); + arena_ = std::make_unique(1 << 28, allocator_); + streams_.push_back(std::make_unique()); + } + + Device* device_; + GpuAllocator* allocator_; + std::unique_ptr arena_; + std::vector> streams_; +}; + +struct KernelParams { + int32_t* array; + int32_t size; +}; + +const char* kernelText = + "using int32_t = int; //#include \n" + "namespace facebook::velox::wave {\n" + " struct KernelParams {\n" + " int32_t* array;\n" + " int32_t size;\n" + " };\n" + "\n" + " void __global__ add1(KernelParams params) {\n" + " for (auto i = threadIdx.x; i < params.size; i += blockDim.x) {\n" + " ++params.array[i];\n" + " }\n" + " }\n" + "\n" + " void __global__ add2(KernelParams params) {\n" + " for (auto i = threadIdx.x; i < params.size; i += blockDim.x) {\n" + " params.array[i] += 2;\n" + " }\n" + " }\n" + "} // namespace\n"; + +void __global__ add3(KernelParams params) { + for (auto i = threadIdx.x; i < params.size; i += blockDim.x) { + params.array[i] += 3; + } +} + +TEST_F(CompileTest, module) { + KernelSpec spec = KernelSpec{ + kernelText, + {"facebook::velox::wave::add1", "facebook::velox::wave::add2"}, + "/tmp/add1.cu"}; + auto module = CompiledModule::create(spec); + int32_t* ptr; + testCuCheck(cuMemAllocManaged( + reinterpret_cast(&ptr), + 1000 * sizeof(int32_t), + CU_MEM_ATTACH_GLOBAL)); + KernelParams record{ptr, 1000}; + memset(ptr, 0, 1000 * sizeof(int32_t)); + void* recordPtr = &record; + auto impl = std::make_unique(); + testCuCheck(cuStreamCreate((CUstream*)&impl->stream, CU_STREAM_DEFAULT)); + auto stream = std::make_unique(std::move(impl)); + module->launch(0, 1, 256, 0, stream.get(), &recordPtr); + testCuCheck(cuStreamSynchronize((CUstream)stream->stream()->stream)); + EXPECT_EQ(1, ptr[0]); + auto info = module->info(0); + EXPECT_EQ(1024, info.maxThreadsPerBlock); + + // See if runtime API kernel works on driver API stream. + add3<<<1, 256, 0, (cudaStream_t)stream->stream()->stream>>>(record); + CUDA_CHECK(cudaGetLastError()); + testCuCheck(cuStreamSynchronize((CUstream)stream->stream()->stream)); + EXPECT_EQ(4, ptr[0]); + + auto stream2 = std::make_unique(); + module->launch(1, 1, 256, 0, stream2.get(), &recordPtr); + stream2->wait(); + EXPECT_EQ(6, ptr[0]); +} + +TEST_F(CompileTest, cache) { + KernelSpec spec = KernelSpec{ + kernelText, + {"facebook::velox::wave::add1", "facebook::velox::wave::add2"}, + "/tmp/add1.cu"}; + auto kernel = + CompiledKernel::getKernel("add1", [&]() -> KernelSpec { return spec; }); + auto buffer = arena_->allocate(1000); + memset(buffer->as(), 0, sizeof(int32_t) * 1000); + KernelParams record{buffer->as(), 1000}; + void* recordPtr = &record; + auto stream = std::make_unique(); + kernel->launch(1, 1, 256, 0, stream.get(), &recordPtr); + stream->wait(); + EXPECT_EQ(2, buffer->as()[0]); +} + +} // namespace facebook::velox::wave diff --git a/velox/experimental/wave/common/tests/CudaTest.cpp b/velox/experimental/wave/common/tests/CudaTest.cpp index 9f924d962f8e7..d3788435a7f81 100644 --- a/velox/experimental/wave/common/tests/CudaTest.cpp +++ b/velox/experimental/wave/common/tests/CudaTest.cpp @@ -578,6 +578,7 @@ class RoundtripThread { kAddBranch, kAddFuncStore, kAddSwitch, + kAddMultiStream, kAddRandom, kAddRandomEmptyWarps, kAddRandomEmptyThreads, @@ -692,6 +693,26 @@ class RoundtripThread { stats.numAdds += op.param1 * op.param2 * 256; break; + case OpCode::kAddMultiStream: + VELOX_CHECK_LE(op.param1, kNumKB); + if (stats.isCpu) { + addOneCpu(op.param1 * 256, op.param2); + } else { + auto [numStreams, numPerStream] = ensureStreams(op.param1); + for (int32_t i = 0; i < numStreams; i++) { + streams_[i]->addOne( + deviceBuffer_->as() + i * numPerStream, + op.param1 * 256 / numStreams, + op.param2, + op.param3); + events_[i]->record(*streams_[i]); + events_[i]->wait(*stream_); + } + stream_->wait(); + } + stats.numAdds += op.param1 * op.param2 * 256; + break; + case OpCode::kAddFuncStore: VELOX_CHECK_LE(op.param1, kNumKB); if (stats.isCpu) { @@ -884,6 +905,9 @@ class RoundtripThread { } else if (str[position] == 'w') { op.opCode = OpCode::kAddSwitch; ++position; + } else if (str[position] == 'm') { + op.opCode = OpCode::kAddMultiStream; + ++position; } op.param1 = parseInt(str, position, 1); op.param2 = parseInt(str, position, 1); @@ -987,6 +1011,21 @@ class RoundtripThread { return result; } + // Returns number of streams, number of items per stream for multi-stream + // execution over 'kb' KB of input. Ensures enough streams and events in + // 'streams_' and 'events_'. Take min 4KB per TB. + std::pair ensureStreams(int32_t kb) { + VELOX_CHECK_EQ(kb & 3, 0); + int32_t numStreams = kb / 4; + if (streams_.size() < numStreams) { + for (auto i = streams_.size(); i < numStreams; ++i) { + streams_.push_back(std::make_unique()); + events_.push_back(std::make_unique()); + } + } + return {numStreams, kb / numStreams / 256}; + } + ArenaSet* const arenas_; Device* device_{nullptr}; WaveBufferPtr deviceBuffer_; @@ -996,6 +1035,10 @@ class RoundtripThread { std::unique_ptr hostInts_; std::unique_ptr stream_; std::unique_ptr event_; + + std::vector> streams_; + std::vector> events_; + int32_t serial_{0}; static inline std::atomic serialCounter_{0}; }; diff --git a/velox/experimental/wave/dwio/decode/tests/CMakeLists.txt b/velox/experimental/wave/dwio/decode/tests/CMakeLists.txt index 6adeeb6268aad..bba31ba3174d3 100644 --- a/velox/experimental/wave/dwio/decode/tests/CMakeLists.txt +++ b/velox/experimental/wave/dwio/decode/tests/CMakeLists.txt @@ -27,4 +27,5 @@ target_link_libraries( GTest::gtest_main gflags::gflags glog::glog - Folly::folly) + Folly::folly + CUDA::nvrtc) diff --git a/velox/experimental/wave/exec/tests/CMakeLists.txt b/velox/experimental/wave/exec/tests/CMakeLists.txt index ae58682051025..92382cc2a1e0a 100644 --- a/velox/experimental/wave/exec/tests/CMakeLists.txt +++ b/velox/experimental/wave/exec/tests/CMakeLists.txt @@ -55,7 +55,8 @@ target_link_libraries( gflags::gflags glog::glog fmt::fmt - ${FILESYSTEM}) + ${FILESYSTEM} + CUDA::nvrtc) add_test(velox_wave_exec_test velox_wave_exec_test) @@ -102,5 +103,6 @@ if(${VELOX_ENABLE_BENCHMARKS}) gflags::gflags glog::glog fmt::fmt - ${FILESYSTEM}) + ${FILESYSTEM} + CUDA::nvrtc) endif()