Skip to content

Commit

Permalink
use cute::bfloat16_t
Browse files Browse the repository at this point in the history
  • Loading branch information
rolandschulz committed Apr 19, 2024
1 parent d0640a1 commit c2d628f
Showing 1 changed file with 8 additions and 12 deletions.
20 changes: 8 additions & 12 deletions examples/cute/tutorial/pvc_sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
#include <cute/numeric/arithmetic_tuple.hpp>

using test_clock = std::chrono::high_resolution_clock;

using sycl::ext::oneapi::bfloat16;
using namespace cute;

bool identityData = false;
bool fixedData = false;
Expand Down Expand Up @@ -48,15 +47,15 @@ static void fill_matrix(std::vector<T> &M, size_t numRows, size_t numCols)
if (identityData)
{
std::generate(std::begin(M), std::end(M), [&]
{ return 1.0f; });
{ return 1.0_bf16; });
}
else if (fixedData)
{
for (size_t r = 0; r < numRows; r++)
{
for (size_t c = 0; c < numCols; c++)
{
M[r * numCols + c] = static_cast<float>(r + c);
M[r * numCols + c] = bfloat16_t(float(r + c));
}
}
}
Expand All @@ -66,7 +65,7 @@ static void fill_matrix(std::vector<T> &M, size_t numRows, size_t numCols)
std::mt19937 rng(dev());
std::uniform_real_distribution<float> dist(-1.0, 1.0);
std::generate(std::begin(M), std::end(M), [&]
{ return dist(rng); });
{ return bfloat16_t(dist(rng)); });
}
}

Expand Down Expand Up @@ -153,7 +152,7 @@ inline size_t time_event(sycl::event &e)
template <int tM, int tN, int tK, int MM, int NN>
static void go_dpas_blockread_vnni_tiled(
sycl::queue queue,
std::vector<float> &c_vec, sycl::buffer<bfloat16> a, sycl::buffer<bfloat16> b,
std::vector<float> &c_vec, sycl::buffer<bfloat16_t> a, sycl::buffer<bfloat16_t> b,
size_t M, size_t N, size_t K,
const std::vector<float> &C_ref)
{
Expand Down Expand Up @@ -193,9 +192,6 @@ static void go_dpas_blockread_vnni_tiled(
auto B = accB.get_multi_ptr<sycl::access::decorated::yes>().get();
auto C = accC.get_multi_ptr<sycl::access::decorated::yes>().get();


using namespace cute;

Tensor tAr = make_tensor<ushort>(Shape<_8, Int<MM>>{});
Tensor tBr = make_tensor<uint>(Shape<_8, Int<NN>>{});
Tensor tCr = make_tensor<float>(Shape<_8, Int<MM>, Int<NN>>{});
Expand Down Expand Up @@ -256,9 +252,9 @@ int main(int argc, char **argv)
const auto N = matrixSize;
const auto K = matrixSize;

std::vector<bfloat16> A_vec(M * K);
std::vector<bfloat16> B_vec(K * N);
std::vector<bfloat16> Bvnni_vec(K * N);
std::vector<bfloat16_t> A_vec(M * K);
std::vector<bfloat16_t> B_vec(K * N);
std::vector<bfloat16_t> Bvnni_vec(K * N);
std::vector<float> C_vec(M * N);
std::vector<float> C_ref(M * N);

Expand Down

0 comments on commit c2d628f

Please sign in to comment.