Use cute::bfloat16_t
rolandschulz authored and taozha2 committed Apr 18, 2024
1 parent b083ad0 commit fb2a10d
Showing 2 changed files with 18 additions and 12 deletions.
examples/cute/tutorial/pvc_sycl/pvc_sycl.cpp (26 changes: 14 additions & 12 deletions)
@@ -20,11 +20,11 @@
 
 using test_clock = std::chrono::high_resolution_clock;
 
-using sycl::ext::oneapi::bfloat16;
+using namespace cute;
 
-using dtype_a = bfloat16;
-using dtype_b = bfloat16;
-using dtype_c = bfloat16;
+using dtype_a = bfloat16_t;
+using dtype_b = bfloat16_t;
+using dtype_c = float;
 using dtype_acc = float;
 
 bool identityData = false;
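For context: `cute::bfloat16_t` is CUTLASS's own `bfloat16_t` (defined in `include/cutlass/bfloat16.h`, patched below) rather than `sycl::ext::oneapi::bfloat16`, and it is only explicitly constructible from `float`, which is why the conversions in `fill_matrix` below become explicit. A minimal host-side sketch (the include and values are illustrative, not part of the commit):

```cpp
#include <cutlass/bfloat16.h>
#include <cstdio>

int main() {
  cutlass::bfloat16_t x(1.5f); // explicit construction from float
  float y = float(x);          // conversion back to float
  std::printf("%f\n", y);      // 1.5 is exactly representable in bf16
  return 0;
}
```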
@@ -50,18 +50,19 @@ std::string makeTestName(const std::string &func, int tM, int tN, int tK,
 template <typename T>
 static void fill_matrix(std::vector<T> &M, size_t numRows, size_t numCols) {
   if (identityData) {
-    std::generate(std::begin(M), std::end(M), [&] { return 1.0f; });
+    std::generate(std::begin(M), std::end(M), [&] { return 1.0_bf16; });
   } else if (fixedData) {
     for (size_t r = 0; r < numRows; r++) {
       for (size_t c = 0; c < numCols; c++) {
-        M[r * numCols + c] = static_cast<float>(r + c);
+        M[r * numCols + c] = bfloat16_t(float(r + c));
       }
     }
   } else {
     std::random_device dev;
     std::mt19937 rng(dev());
     std::uniform_real_distribution<float> dist(-1.0, 1.0);
-    std::generate(std::begin(M), std::end(M), [&] { return dist(rng); });
+    std::generate(std::begin(M), std::end(M),
+                  [&] { return bfloat16_t(dist(rng)); });
   }
 }
 
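Both new conversions in `fill_matrix` are reachable through `using namespace cute;`: the `1.0_bf16` user-defined literal produces a `bfloat16_t` directly, and the random branch wraps the `float` from `dist(rng)` in an explicit `bfloat16_t(...)`. A stand-alone sketch of the same pattern (assuming the CuTe include tree; the vector size is illustrative):

```cpp
#include <cute/tensor.hpp> // pulls in bfloat16_t and the _bf16 literal
#include <algorithm>
#include <random>
#include <vector>

using namespace cute;

int main() {
  std::vector<bfloat16_t> M(16);
  // identityData path: the literal already has type bfloat16_t.
  std::generate(M.begin(), M.end(), [] { return 1.0_bf16; });
  // random path: dist(rng) is float, so the conversion must be explicit.
  std::random_device dev;
  std::mt19937 rng(dev());
  std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
  std::generate(M.begin(), M.end(), [&] { return bfloat16_t(dist(rng)); });
  return 0;
}
```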
@@ -179,16 +180,17 @@ go_dpas_blockread_vnni_tiled(sycl::queue queue, std::vector<dtype_acc> &c_vec,
   auto B = accB.get_multi_ptr<sycl::access::decorated::yes>().get();
   auto C = accC.get_multi_ptr<sycl::access::decorated::yes>().get();
 
-  using namespace cute;
-
   Tensor tAr = make_tensor<ushort>(Shape<_8, Int<MM>>{});
   Tensor tBr = make_tensor<uint>(Shape<_8, Int<NN>>{});
   Tensor tCr =
       make_tensor<dtype_acc>(Shape<_8, Int<MM>, Int<NN>>{});
 
-  auto A_copy = make_xe_2d_copy<XE_2D_LOAD>(make_tensor(make_gmem_ptr(A), make_shape(M, K)));
-  auto B_copy = make_xe_2d_copy<XE_2D_LOAD>(make_tensor(make_gmem_ptr(B), make_shape(K, N)));
-  auto C_copy = make_xe_2d_copy<XE_2D_SAVE>(make_tensor(make_gmem_ptr(C), make_shape(M, N)));
+  auto A_copy = make_xe_2d_copy<XE_2D_LOAD>(
+      make_tensor(make_gmem_ptr(A), make_shape(M, K)));
+  auto B_copy = make_xe_2d_copy<XE_2D_LOAD>(
+      make_tensor(make_gmem_ptr(B), make_shape(K, N)));
+  auto C_copy = make_xe_2d_copy<XE_2D_SAVE>(
+      make_tensor(make_gmem_ptr(C), make_shape(M, N)));
   // TODO: - decide on how to deal with vector types
   //       - create layouts with tiling/partitioning
 
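The reflowed `make_xe_2d_copy` calls wrap plain CuTe global-memory tensors; the inner expression is standard CuTe API. A hedged sketch of just that inner part (the `demo` function and the layout comment are my additions, not the commit's):

```cpp
#include <cute/tensor.hpp>

using namespace cute;

// Wrap a raw global pointer in a CuTe tensor, as the copy atoms above do.
void demo(float *A, int M, int K) {
  // The two-argument form deduces a compact column-major (LayoutLeft)
  // layout from the shape.
  Tensor gA = make_tensor(make_gmem_ptr(A), make_shape(M, K));
  print(gA.layout()); // e.g. (M,K):(_1,M)
}
```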
include/cutlass/bfloat16.h (4 changes: 4 additions & 0 deletions)
@@ -118,6 +118,10 @@ struct alignas(2) bfloat16_t {
 
   asm("cvt.rn.bf16.f32 %0, %1;\n" : "=h"(storage) : "f"(x));
 
+#elif defined(CUTLASS_ENABLE_SYCL)
+
+  storage = sycl::ext::oneapi::detail::bfloat16ToBits(sycl::ext::oneapi::bfloat16(x));
+
 #else
   uint32_t bits;
 
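The new `#elif` branch reuses SYCL's conversion and then extracts the raw 16-bit pattern via `bfloat16ToBits`, matching what the PTX and portable branches do: round the `float` to nearest-even at bit 16 and keep the upper half. A self-contained sketch of that rounding (this is the standard bias trick, not the CUTLASS source verbatim):

```cpp
#include <cstdint>
#include <cstring>

// float -> bfloat16 bit pattern with round-to-nearest-even.
uint16_t float_to_bf16_bits(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits)); // well-defined type pun
  if ((bits & 0x7f800000u) != 0x7f800000u) {
    // Finite value: adding 0x7fff (+1 if the surviving bit is odd) rounds
    // to nearest, ties to even, before truncation.
    bits += 0x7fffu + ((bits >> 16) & 1u);
  } else if (bits & 0x007fffffu) {
    bits |= 0x00400000u; // NaN: keep a mantissa bit so it stays NaN
  }
  return uint16_t(bits >> 16); // sign, exponent, top 7 mantissa bits
}
```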
