From f55e7be4878190f9aacbe571eee9e9a68466a123 Mon Sep 17 00:00:00 2001 From: jiyang1011 Date: Mon, 29 Apr 2024 01:18:02 -0700 Subject: [PATCH] rebase cutlass-fork intel gpu supporting --- build.sh | 10 +- examples/cute/tutorial/pvc_sycl/pvc_sycl.cpp | 202 +++++++++---------- include/cute/util/sycl_vec.hpp | 2 +- include/cutlass/cutlass.h | 29 --- 4 files changed, 100 insertions(+), 143 deletions(-) diff --git a/build.sh b/build.sh index 971246a623..234fcfd6fe 100644 --- a/build.sh +++ b/build.sh @@ -1,12 +1,16 @@ sycl_compiler_path=/opt/cutlass_compiler/ target=./examples/cute/tutorial/pvc_sycl cuda_path=/usr/local/cuda-12.3/ +mkl_path=/opt/intel/oneapi/mkl/2024.1 rm -rf $target -export CPATH=$sycl_compiler_path:$sycl_compiler_path/include/:$sycl_compiler_path/include/sycl/:/opt/intel/oneapi/mkl/2024.1/include/ +export ZE_AFFINITY_MASK=0 +export CPATH=$sycl_compiler_path:$sycl_compiler_path/include/:$sycl_compiler_path/include/sycl/:$mkl_path/include/ export LIBRARY_PATH=/opt/intel/oneapi/mkl/2024.1/lib/ -export LD_LIBRARY_PATH=/opt/intel/oneapi/mkl/2024.1/lib/:${sycl_compiler_path}/lib/ +export LD_LIBRARY_PATH=$mkl_path/lib/:${sycl_compiler_path}/lib/ export IGC_EnableVISANoSchedule=1 export IGC_ShaderDumpEnable=1 export IGC_DumpToCustomDir=./mm_dumps_prefetch_coop export IGC_VATemp=1 -cmake .. -G Ninja -DCMAKE_CUDA_HOST_COMPILER=${sycl_compiler_path}/bin/clang++ -DCMAKE_CUDA_COMPILER=$cuda_path/bin/nvcc -DCUTLASS_ENABLE_SYCL=ON -DDPCPP_SYCL_TARGET=intel_gpu_pvc -DCMAKE_CXX_COMPILER=${sycl_compiler_path}/bin/clang++ -DCMAKE_CXX_FLAGS=" -DITEM_SIZE_X=4 -DITEM_SIZE_Y=32 -DSG_SIZE_X=64 -DSG_SIZE_Y=ITEM_SIZE_Y -DWG_SIZE_X=256 -DWG_SIZE_Y=256 -DKK=2 -DPREFETCH_DEFAULT -lmkl_intel_lp64 -lmkl_sequential -lmkl_core" && ninja -v $target && ONEAPI_DEVICE_SELECTOR=level_zero:gpu $target +cmake .. -G Ninja -DCMAKE_CUDA_HOST_COMPILER=${sycl_compiler_path}/bin/clang++ -DCMAKE_CUDA_COMPILER=$cuda_path/bin/nvcc \ +-DCUTLASS_ENABLE_SYCL=ON -DDPCPP_SYCL_TARGET=intel_gpu_pvc -DCMAKE_CXX_COMPILER=${sycl_compiler_path}/bin/clang++ \ +-DCMAKE_CXX_FLAGS=" -lmkl_intel_lp64 -lmkl_sequential -lmkl_core -DPREFETCH_DEFAULT" && ninja -v $target && ONEAPI_DEVICE_SELECTOR=level_zero:gpu $target diff --git a/examples/cute/tutorial/pvc_sycl/pvc_sycl.cpp b/examples/cute/tutorial/pvc_sycl/pvc_sycl.cpp index 3c539c6c0c..189c8d4746 100644 --- a/examples/cute/tutorial/pvc_sycl/pvc_sycl.cpp +++ b/examples/cute/tutorial/pvc_sycl/pvc_sycl.cpp @@ -40,34 +40,14 @@ dtype_acc threshold = 0.01f; #define split_barrier_wait() __builtin_IB_work_group_barrier_wait(0) template -static void fill_matrix(std::vector &M, size_t numRows, size_t numCols) { - if (identityData) { - std::generate(std::begin(M), std::end(M), [&] { return 1.0_bf16; }); - } else if (fixedData) { - for (size_t r = 0; r < numRows; r++) { - for (size_t c = 0; c < numCols; c++) { - M[r * numCols + c] = bfloat16_t(float(r + c)); - } - } - } else { - std::random_device dev; - std::mt19937 rng(dev()); - std::uniform_real_distribution dist(-1.0, 1.0); - std::generate(std::begin(M), std::end(M), - [&] { return bfloat16_t(dist(rng)); }); - } -} - -template -static void fill_matrix_B(T *M, size_t numRows, size_t numCols) { +static void init_matrix(T *M, size_t numRows, size_t numCols) { + std::random_device dev; + std::mt19937 rng(dev()); + std::uniform_real_distribution dist(-1.0, 1.0); for (size_t r = 0; r < numRows; r++) { for (size_t c = 0; c < numCols; c++) { - M[r * numCols + c] = bfloat16_t(0.0f); + M[r * numCols + c] = bfloat16_t(dist(rng)); } - }; - - for (size_t r = 0; r < numRows; r++) { - M[r * numCols + r] = bfloat16_t(1.0f); } } @@ -96,11 +76,11 @@ void check_results(size_t M, size_t N, const T *C, const T *C_ref) { err = std::max(localErr, err); if (localErr >= threshold) { error_cnt++; - // std::cerr << "Error at m = " << m << ", n = " << n << ": (local - // error" - // << localErr << "): Wanted " << C_ref[index] << ", got " - // << C[index] << std::endl; - // return; +#if 0 + std::cerr << "Error at m = " << m << ", n = " << n << ": (local error" + << localErr << "): Wanted " << C_ref[index] << ", got " + << C[index] << std::endl; +#endif } } } @@ -149,19 +129,17 @@ void cute_gemm(size_t M, size_t K, size_t N) { (dtype_b *)syclcompat::malloc_host(sizeof(dtype_b) * N * K); vnni_matrix(Bvnni_host, B_host, K, N, 2); - Tensor tAr = make_tensor(Shape<_8, Int>{}); - Tensor tBr = make_tensor(Shape<_8, Int>{}); - Tensor tCr = - make_tensor(Shape<_8, Int, Int>{}); + queue.memcpy(A_dev, A_host, sizeof(dtype_a) * M * K).wait(); + queue.memcpy(B_dev, Bvnni_host, sizeof(dtype_b) * N * K).wait(); + queue.memcpy(C_dev, C_host, sizeof(dtype_c) * M * N).wait(); - auto A_copy = make_xe_2d_copy( - make_tensor(make_gmem_ptr(A), make_shape(M, K))); - auto B_copy = make_xe_2d_copy( - make_tensor(make_gmem_ptr(B), make_shape(K, N))); - auto C_copy = make_xe_2d_copy( - make_tensor(make_gmem_ptr(C), make_shape(M, N))); - // TODO: - decide on how to deal with vector types - // - create layouts with tiling/partitioning + printf("Computing reference...\n"); + dtype_acc *C_ref_host = + (dtype_acc *)syclcompat::malloc_host(sizeof(dtype_acc) * M * N); + + get_gemm_gold( + M, N, K, mem_layout::row_major, mem_layout::row_major, (dtype_a *)A_host, + (dtype_b *)B_host, (dtype_c *)C_ref_host); printf("Running gemm tests, MKN: (%d, %d, %d)...\n", M, K, N); @@ -195,93 +173,97 @@ void cute_gemm(size_t M, size_t K, size_t N) { for (uint32_t test = 0; test < total_iterations; test++) { sycl::event ev; ev = queue.submit([&](sycl::handler &cgh) { - cgh.parallel_for(nd_range, [=](sycl::nd_item<2> - id) [[sycl::reqd_sub_group_size(16)]] { - const int M = id.get_global_range(0) * ITEM_SIZE_Y; - const int N = id.get_global_range(1) * ITEM_SIZE_X; - const int m = id.get_group(0) * WG_SIZE_Y + - (get_sub_group_id() / SGS_PER_WG_X) * SG_SIZE_Y; - const int n = id.get_group(1) * WG_SIZE_X + - (get_sub_group_id() % SGS_PER_WG_X) * SG_SIZE_X; - - Tensor tAr = make_tensor(Shape, Int<1>>{}); - Tensor tBr = make_tensor(Shape, Int>{}); - Tensor tCr = make_tensor(Shape<_8, Int, Int>{}); - - auto A_copy = - make_xe_2d_A_copy(make_tensor(make_gmem_ptr(A), make_shape(M, K))); - auto B_copy = - make_xe_2d_B_copy(make_tensor(make_gmem_ptr(B), make_shape(K, N))); - auto C_copy = - make_xe_2d_copy(make_tensor(make_gmem_ptr(C), make_shape(M, N))); - // TODO: - decide on how to deal with vector types - // - create layouts with tiling/partitioning - - Tensor tAi = make_tensor( - make_inttuple_iter(m, 0), - make_layout(make_shape(_1{}, _1{}, K), - make_stride(_1{}, MM * tM * E<0>{}, E<1>{}))); - Tensor tBi = - make_tensor(make_inttuple_iter(0, n), - make_layout(make_shape(_1{}, K, Int{}), - make_stride(_1{}, E<0>{}, tN * E<1>{}))); - Tensor tCi = make_tensor( - make_inttuple_iter(m, n), - make_layout(Shape<_1, Int, Int>{}, - make_stride(_1{}, tM * E<0>{}, tN * E<1>{}))); - TiledMMA, - Layout>> - tiled_mma; - - int prefetch_k = 0; + cgh.parallel_for( + nd_range, [=](sycl::nd_item<2> id) [[sycl::reqd_sub_group_size( + subgroup_size)]] { + const int m = id.get_group(1) * wg_tile_m + + (get_sub_group_id() / sg_per_wg_n) * sg_tile_m; + const int n = id.get_group(0) * wg_tile_n + + (get_sub_group_id() % sg_per_wg_n) * sg_tile_n; + + Tensor tAr = + make_tensor(Shape, Int<1>>{}); + Tensor tBr = + make_tensor(Shape, Int>{}); + Tensor tCr = + make_tensor(Shape, Int, Int>{}); + + auto A_copy = make_xe_2d_copy( + make_tensor(make_gmem_ptr(A_dev), make_shape(M, K))); + auto B_copy = make_xe_2d_copy( + make_tensor(make_gmem_ptr(B_dev), make_shape(K, N))); + auto C_copy = make_xe_2d_copy( + make_tensor(make_gmem_ptr(C_dev), make_shape(M, N))); + // TODO: - decide on how to deal with vector types + // - create layouts with tiling/partitioning + + Tensor tAi = make_tensor( + make_inttuple_iter(m, 0), + make_layout(make_shape(_1{}, _1{}, K), + make_stride(_1{}, MM * tM * E<0>{}, E<1>{}))); + Tensor tBi = make_tensor( + make_inttuple_iter(0, n), + make_layout(make_shape(_1{}, K, Int{}), + make_stride(_1{}, E<0>{}, tN * E<1>{}))); + Tensor tCi = make_tensor( + make_inttuple_iter(m, n), + make_layout(Shape<_1, Int, Int>{}, + make_stride(_1{}, tM * E<0>{}, tN * E<1>{}))); + TiledMMA, + Layout>> + tiled_mma; + + uint32_t prefetch_k = 0; #ifdef PREFETCH_DEFAULT - for (int p = 0; p < PREFETCH_DISTANCE; p++) { + for (uint32_t p = 0; p < PREFETCH_DISTANCE; p++) { #ifdef B_VNNI - HELPER_NAME(btile_block_prefetch_vnni, 4, 4) - ((ushort *)B, tN, K, N, prefetch_k, n); + HELPER_NAME(btile_block_prefetch_vnni, 4, 4) + ((ushort *)B_dev, tN, K, N, prefetch_k, n); #else HELPER_NAME(btile_block_prefetch_rowmajor, 4, 4) ((ushort *)B_dev, tN, K, N, prefetch_k, n); #endif - HELPER_NAME(atile_block_prefetch_rowmajor, 4, 4) - ((ushort *)A, tM, M, K, m, prefetch_k); - prefetch_k += tK * KK; - } + HELPER_NAME(atile_block_prefetch_rowmajor, 4, 4) + ((ushort *)A_dev, tM, M, K, m, prefetch_k); + prefetch_k += tK * KK; + } #endif - for (int k = 0; k < K; k += tK * KK) { - copy(A_copy, tAi(_, _, k), tAr); - copy(B_copy, tBi(_, k / 2, _), tBr); + for (int k = 0; k < K + tK * KK - 1; k += tK * KK) { + copy(A_copy, tAi(_, _, k), tAr); + copy(B_copy, tBi(_, k / KK, _), tBr); #ifdef PREFETCH_DEFAULT - for (int p = 0; p < PREFETCH_DISTANCE; p++) { + for (uint32_t p = 0; p < PREFETCH_DISTANCE; p++) { #ifdef B_VNNI - HELPER_NAME(btile_block_prefetch_vnni, 4, 4) - ((ushort *)B, tN, K, N, prefetch_k, n); + HELPER_NAME(btile_block_prefetch_vnni, 4, 4) + ((ushort *)B_dev, tN, K, N, prefetch_k, n); #else HELPER_NAME(btile_block_prefetch_rowmajor, 4, 4) ((ushort *)B_dev, tN, K, N, prefetch_k, n); #endif - HELPER_NAME(atile_block_prefetch_rowmajor, 4, 4) - ((ushort *)A, tM, M, K, m, prefetch_k); - prefetch_k += tK * KK; - } + HELPER_NAME(atile_block_prefetch_rowmajor, 4, 4) + ((ushort *)A_dev, tM, M, K, m, prefetch_k); + prefetch_k += tK * KK; + } #endif - auto tAr_view = make_tensor(static_cast(tAr).data(), - Shape<_8, Int, Int>{}); - auto tBr_view = make_tensor(static_cast(tBr).data(), - Shape<_16, Int, Int>{}); - for (int kl = 0; kl < KK; kl++) { - gemm(tiled_mma, tAr_view(_, _, kl), tBr_view(_, kl, _), tCr); - } - } - - copy(C_copy, tCr, tCi); - }); + auto tAr_view = + make_tensor(static_cast(tAr).data(), + Shape, Int, Int>{}); + auto tBr_view = + make_tensor(static_cast(tBr).data(), + Shape, Int, Int>{}); + for (uint32_t kl = 0; kl < KK; kl++) { + gemm(tiled_mma, tAr_view(_, _, kl), tBr_view(_, kl, _), tCr); + } + } + + copy(C_copy, tCr, tCi); + }); }); ev.wait_and_throw(); - event_times[test] = time_event(ev) / 1e6; + event_times[test] = time_event(ev) / 1e9; // seconds } double average_event_time = 0.f; diff --git a/include/cute/util/sycl_vec.hpp b/include/cute/util/sycl_vec.hpp index 25df0be72a..6c6c3837a3 100644 --- a/include/cute/util/sycl_vec.hpp +++ b/include/cute/util/sycl_vec.hpp @@ -51,4 +51,4 @@ typedef ushort __attribute__((ext_vector_type(16))) ushort16; typedef ushort __attribute__((ext_vector_type(32))) ushort32; typedef ushort __attribute__((ext_vector_type(64))) ushort64; typedef uint __attribute__((ext_vector_type(32))) uint32; -typedef int __attribute__((ext_vector_type(16))) int16; +typedef int __attribute__((ext_vector_type(16))) int16; \ No newline at end of file diff --git a/include/cutlass/cutlass.h b/include/cutlass/cutlass.h index 0f0bf5fa2c..32851afdb5 100644 --- a/include/cutlass/cutlass.h +++ b/include/cutlass/cutlass.h @@ -121,7 +121,6 @@ CUTLASS_HOST_DEVICE uint ThreadIdxX() { return syclcompat::local_id::x(); #else return 0; - return 0; #endif } @@ -132,7 +131,6 @@ CUTLASS_HOST_DEVICE uint ThreadIdxY() { return syclcompat::local_id::y(); #else return 0; - return 0; #endif } @@ -143,7 +141,6 @@ CUTLASS_HOST_DEVICE uint ThreadIdxZ() { return syclcompat::local_id::z(); #else return 0; - return 0; #endif } @@ -154,7 +151,6 @@ CUTLASS_HOST_DEVICE uint BlockIdxX() { return syclcompat::work_group_id::x(); #else return 0; - return 0; #endif } @@ -165,7 +161,6 @@ CUTLASS_HOST_DEVICE uint BlockIdxY() { return syclcompat::work_group_id::y(); #else return 0; - return 0; #endif } @@ -176,7 +171,6 @@ CUTLASS_HOST_DEVICE uint BlockIdxZ() { return syclcompat::work_group_id::z(); #else return 0; - return 0; #endif } @@ -187,7 +181,6 @@ CUTLASS_HOST_DEVICE uint BlockDimX() { return syclcompat::work_group_range::x(); #else return 0; - return 0; #endif } @@ -198,7 +191,6 @@ CUTLASS_HOST_DEVICE uint BlockDimY() { return syclcompat::work_group_range::y(); #else return 0; - return 0; #endif } @@ -209,7 +201,6 @@ CUTLASS_HOST_DEVICE uint BlockDimZ() { return syclcompat::work_group_range::z(); #else return 0; - return 0; #endif } @@ -220,7 +211,6 @@ CUTLASS_HOST_DEVICE uint GridDimX() { return syclcompat::global_range::x(); #else return 0; - return 0; #endif } @@ -231,7 +221,6 @@ CUTLASS_HOST_DEVICE uint GridDimY() { return syclcompat::global_range::y(); #else return 0; - return 0; #endif } @@ -242,7 +231,6 @@ CUTLASS_HOST_DEVICE uint GridDimZ() { return syclcompat::global_range::z(); #else return 0; - return 0; #endif } @@ -264,7 +252,6 @@ CUTLASS_DEVICE int syncthreads_and(int cond) { assert(false); #else return 0; - return 0; #endif } @@ -296,10 +283,8 @@ uint byte_perm(uint x, uint y, uint s) { // TODO: Add SYCL equivalent function assert(false); return 0; - return 0; #else return 0; - return 0; #endif } @@ -313,10 +298,8 @@ uint shfl_up_sync(const unsigned mask, const uint var, const int delta, const in // TODO: Add SYCL equivalent function assert(false); return 0; - return 0; #else return 0; - return 0; #endif } @@ -328,10 +311,8 @@ uint shfl_down_sync(const unsigned mask, const uint var, const int delta, const // TODO: Add SYCL equivalent function assert(false); return 0; - return 0; #else return 0; - return 0; #endif } @@ -343,10 +324,8 @@ uint shfl_sync(const unsigned mask, const uint var, const int delta, const int w // TODO: Add SYCL equivalent function assert(false); return 0; - return 0; #else return 0; - return 0; #endif } @@ -360,10 +339,8 @@ CUTLASS_DEVICE T hfma2(const T a, const T b, const T c) { // TODO: Add SYCL equivalent function assert(false); return T(0); - return T(0); #else return T(0); - return T(0); #endif } @@ -377,9 +354,6 @@ CUTLASS_DEVICE int atomicAdd(int *address, int val) { #else return 0; #endif -#else - return 0; -#endif } CUTLASS_DEVICE int atomicCAS(int *address, int compare, int val) { @@ -389,9 +363,6 @@ CUTLASS_DEVICE int atomicCAS(int *address, int compare, int val) { #else return 0; #endif -#else - return 0; -#endif } #endif