From 6d1586122c8a0073890b9334c2f44680dcdabbe9 Mon Sep 17 00:00:00 2001 From: Roland Schulz Date: Wed, 17 Apr 2024 14:45:39 -0700 Subject: [PATCH] use cute::bfloat16_t --- examples/cute/tutorial/pvc_sycl.cpp | 20 ++++++++------------ include/cutlass/bfloat16.h | 4 ++++ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/examples/cute/tutorial/pvc_sycl.cpp b/examples/cute/tutorial/pvc_sycl.cpp index 3bff9751cf..96354a9df7 100644 --- a/examples/cute/tutorial/pvc_sycl.cpp +++ b/examples/cute/tutorial/pvc_sycl.cpp @@ -17,8 +17,7 @@ #include using test_clock = std::chrono::high_resolution_clock; - -using sycl::ext::oneapi::bfloat16; +using namespace cute; bool identityData = false; bool fixedData = false; @@ -48,7 +47,7 @@ static void fill_matrix(std::vector &M, size_t numRows, size_t numCols) if (identityData) { std::generate(std::begin(M), std::end(M), [&] - { return 1.0f; }); + { return 1.0_bf16; }); } else if (fixedData) { @@ -56,7 +55,7 @@ static void fill_matrix(std::vector &M, size_t numRows, size_t numCols) { for (size_t c = 0; c < numCols; c++) { - M[r * numCols + c] = static_cast(r + c); + M[r * numCols + c] = bfloat16_t(float(r + c)); } } } @@ -66,7 +65,7 @@ static void fill_matrix(std::vector &M, size_t numRows, size_t numCols) std::mt19937 rng(dev()); std::uniform_real_distribution dist(-1.0, 1.0); std::generate(std::begin(M), std::end(M), [&] - { return dist(rng); }); + { return bfloat16_t(dist(rng)); }); } } @@ -153,7 +152,7 @@ inline size_t time_event(sycl::event &e) template static void go_dpas_blockread_vnni_tiled( sycl::queue queue, - std::vector &c_vec, sycl::buffer a, sycl::buffer b, + std::vector &c_vec, sycl::buffer a, sycl::buffer b, size_t M, size_t N, size_t K, const std::vector &C_ref) { @@ -193,9 +192,6 @@ static void go_dpas_blockread_vnni_tiled( auto B = accB.get_multi_ptr().get(); auto C = accC.get_multi_ptr().get(); - - using namespace cute; - Tensor tAr = make_tensor(Shape<_8, Int>{}); Tensor tBr = make_tensor(Shape<_8, Int>{}); Tensor tCr = make_tensor(Shape<_8, Int, Int>{}); @@ -256,9 +252,9 @@ int main(int argc, char **argv) const auto N = matrixSize; const auto K = matrixSize; - std::vector A_vec(M * K); - std::vector B_vec(K * N); - std::vector Bvnni_vec(K * N); + std::vector A_vec(M * K); + std::vector B_vec(K * N); + std::vector Bvnni_vec(K * N); std::vector C_vec(M * N); std::vector C_ref(M * N); diff --git a/include/cutlass/bfloat16.h b/include/cutlass/bfloat16.h index 75cadbfa43..081a0960ef 100644 --- a/include/cutlass/bfloat16.h +++ b/include/cutlass/bfloat16.h @@ -118,6 +118,10 @@ struct alignas(2) bfloat16_t { asm("cvt.rn.bf16.f32 %0, %1;\n" : "=h"(storage) : "f"(x)); + #elif defined(CUTLASS_ENABLE_SYCL) + + storage = sycl::ext::oneapi::detail::bfloat16ToBits(sycl::ext::oneapi::bfloat16(x)); + #else uint32_t bits;