diff --git a/velox/experimental/wave/common/tests/CudaTest.cpp b/velox/experimental/wave/common/tests/CudaTest.cpp index c141f8dc9f9c..71922a4d8ca4 100644 --- a/velox/experimental/wave/common/tests/CudaTest.cpp +++ b/velox/experimental/wave/common/tests/CudaTest.cpp @@ -14,8 +14,9 @@ * limitations under the License. */ -#include +#include "velox/experimental/wave/common/tests/CudaTest.h" +#include // @manual #include #include #include @@ -31,7 +32,8 @@ #include "velox/common/time/Timer.h" #include "velox/experimental/wave/common/GpuArena.h" #include "velox/experimental/wave/common/tests/BlockTest.h" -#include "velox/experimental/wave/common/tests/CudaTest.h" + +#include DEFINE_int32(num_streams, 0, "Number of paralll streams"); DEFINE_int32(op_size, 0, "Size of invoke kernel (ints read and written)"); @@ -900,5 +902,9 @@ TEST_F(CudaTest, reduceMatrix) { int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); folly::init(&argc, &argv); + if (int device; cudaGetDevice(&device) != cudaSuccess) { + LOG(WARNING) << "No CUDA detected, skipping all tests"; + return 0; + } return RUN_ALL_TESTS(); } diff --git a/velox/experimental/wave/common/tests/IdMapTest.cu b/velox/experimental/wave/common/tests/IdMapTest.cu index e1289ba783e7..cc8b322400d7 100644 --- a/velox/experimental/wave/common/tests/IdMapTest.cu +++ b/velox/experimental/wave/common/tests/IdMapTest.cu @@ -173,5 +173,9 @@ TEST(IdMapTest, overflowNoEmptyMarker) { int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); folly::Init follyInit(&argc, &argv); + if (int device; cudaGetDevice(&device) != cudaSuccess) { + LOG(WARNING) << "No CUDA detected, skipping all tests"; + return 0; + } return RUN_ALL_TESTS(); } diff --git a/velox/experimental/wave/exec/tests/AggregationTest.cpp b/velox/experimental/wave/exec/tests/AggregationTest.cpp index 596a6dd64972..499cc36bad55 100644 --- a/velox/experimental/wave/exec/tests/AggregationTest.cpp +++ b/velox/experimental/wave/exec/tests/AggregationTest.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ +#include // @manual #include #include #include @@ -46,6 +47,12 @@ class AggregationTest : public OperatorTestBase { OperatorTestBase::SetUpTestCase(); wave::registerWave(); } + + void SetUp() override { + if (int device; cudaGetDevice(&device) != cudaSuccess) { + GTEST_SKIP() << "No CUDA detected, skipping all tests"; + } + } }; TEST_F(AggregationTest, singleKeySingleAggregate) { diff --git a/velox/experimental/wave/exec/tests/FilterProjectTest.cpp b/velox/experimental/wave/exec/tests/FilterProjectTest.cpp index b211898e35ae..3b7fdb3a94ce 100644 --- a/velox/experimental/wave/exec/tests/FilterProjectTest.cpp +++ b/velox/experimental/wave/exec/tests/FilterProjectTest.cpp @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include // @manual #include "velox/dwio/common/tests/utils/BatchMaker.h" #include "velox/exec/PlanNodeStats.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" @@ -29,6 +30,9 @@ using facebook::velox::test::BatchMaker; class FilterProjectTest : public OperatorTestBase { protected: void SetUp() override { + if (int device; cudaGetDevice(&device) != cudaSuccess) { + GTEST_SKIP() << "No CUDA detected, skipping all tests"; + } wave::registerWave(); }