Solve issues
aacostadiaz committed Dec 3, 2024
1 parent 389d3bd · commit a119855
Showing 19 changed files with 87 additions and 50 deletions.
examples/cute/tutorial/tiled_copy_sycl.cpp (2 additions, 2 deletions)

@@ -192,8 +192,8 @@ int main(int argc, char** argv)
     return -1;
   }
   // Equivalent check to the above
-  if (not weakly_compatible(block_shape, tensor_shape)) {
-    std::cerr << "Expected the tensors to be weakly compatible with the block_shape." << std::endl;
+  if (not evenly_divides(tensor_shape, block_shape)) {
+    std::cerr << "Expected the block_shape to evenly divide the tensor shape." << std::endl;
     return -1;
   }
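
The new check also reads in the natural argument order: evenly_divides(tensor_shape, block_shape) is true exactly when every mode of block_shape divides the matching mode of tensor_shape with no remainder. A minimal usage sketch, assuming a CUTLASS/CuTe include path (the shapes here are illustrative, not from the tutorial):

#include <cute/tensor.hpp>

int main() {
  using namespace cute;
  auto tensor_shape = make_shape(Int<256>{}, Int<512>{});  // full tensor extents
  auto block_shape  = make_shape(Int<128>{}, Int<64>{});   // one tile
  // True when block_shape tiles tensor_shape with no remainder.
  if (not evenly_divides(tensor_shape, block_shape)) {
    return -1;  // mirrors the tutorial's error path
  }
  return 0;
}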

include/cute/arch/util.hpp (9 additions, 16 deletions)

@@ -257,8 +257,7 @@ explode(Fn fn,
   return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]..., e[Ie]..., f[If]...);
 }

-#if defined(CUTLASS_ENABLE_SYCL)
-template <class MMA_Op,
+template <class Fn,
           class PtrD, int... Id,
           class PtrA, int... Ia,
           class PtrB, int... Ib,
@@ -277,29 +276,23 @@ explode(Fn fn,
         PtrF&& f, int_sequence<If...>,
         PtrG&& g, int_sequence<Ig...>)
 {
-  return MMA_Op::fma(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]..., e[Ie]..., f[If]..., g[Ig]...);
+  return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]..., e[Ie]..., f[If]..., g[Ig]...);
 }
-#else
-template <class Fn,
+
+#if defined(CUTLASS_ENABLE_SYCL)
+template <class MMA_Op,
           class PtrD, int... Id,
           class PtrA, int... Ia,
           class PtrB, int... Ib,
-          class PtrC, int... Ic,
-          class PtrE, int... Ie,
-          class PtrF, int... If,
-          class PtrG, int... Ig>
+          class PtrC, int... Ic>
 CUTE_HOST_DEVICE constexpr
 void
-explode(Fn fn,
-        PtrD&& d, int_sequence<Id...>,
+explode_mma(PtrD&& d, int_sequence<Id...>,
         PtrA&& a, int_sequence<Ia...>,
         PtrB&& b, int_sequence<Ib...>,
-        PtrC&& c, int_sequence<Ic...>,
-        PtrE&& e, int_sequence<Ie...>,
-        PtrF&& f, int_sequence<If...>,
-        PtrG&& g, int_sequence<Ig...>)
+        PtrC&& c, int_sequence<Ic...>)
 {
-  return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]..., e[Ie]..., f[If]..., g[Ig]...);
+  return MMA_Op::fma(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]...);
 }
 #endif

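Both overloads rely on the same trick: an int_sequence pack expansion flattens arrays of register fragments into one argument list. A self-contained sketch of the pattern using std::index_sequence (reduced to a single array; the real helpers fan out up to seven):

#include <cstddef>
#include <utility>

// Calls fn(p[0], p[1], ..., p[N-1]) via pack expansion,
// as in cute::detail::explode.
template <class Fn, class Ptr, std::size_t... I>
constexpr auto explode(Fn fn, Ptr p, std::index_sequence<I...>) {
  return fn(p[I]...);
}

constexpr int fma3(int a, int b, int c) { return a * b + c; }

int main() {
  int regs[3] = {1, 2, 3};
  // Expands to fma3(regs[0], regs[1], regs[2]).
  return explode(fma3, regs, std::make_index_sequence<3>{}) == 5 ? 0 : 1;
}

The separate explode_mma is presumably needed because MMA_Op::fma is a static member named by a template parameter rather than a callable object, so the SYCL path names MMA_Op explicitly instead of passing a fn argument.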
include/cute/atom/mma_traits.hpp (1 addition, 1 deletion)

@@ -144,7 +144,7 @@ mma_unpack(AnyMMATraits const& traits,
   CUTE_STATIC_ASSERT_V(size(rC) == Int<RegNumC>{});

 #if defined(CUTLASS_ENABLE_SYCL)
-  detail::explode<MMA_Op>(rD, make_int_sequence<RegNumD>{},
+  detail::explode_mma<MMA_Op>(rD, make_int_sequence<RegNumD>{},
                           rA, make_int_sequence<RegNumA>{},
                           rB, make_int_sequence<RegNumB>{},
                           rC, make_int_sequence<RegNumC>{});
include/cute/int_tuple.hpp (7 additions, 4 deletions)

@@ -34,16 +34,13 @@
 #include <cute/container/array.hpp>             // cute::array
 #include <cute/container/tuple.hpp>             // cute::is_tuple
 #include <cute/numeric/integral_constant.hpp>   // cute::Int
-#include <cute/algorithm/tuple_algorithms.hpp>  // cute::transform

 /** IntTuple is an integer or a tuple of IntTuples.
  * This file holds utilities for working with IntTuples,
  * but does not hold a concrete concept or class of IntTuple.
  */

-namespace cute
-{
-
+namespace cute {
 // Implementation of get<0>(Integral).
 // Even though is_tuple<Integral> is false and tuple_size<Integral> doesn't compile,
 // CuTe defines rank(Integral) as 1, so it's useful for get<0>(Integral) to return its input
@@ -65,6 +62,12 @@ get(T&& t) noexcept
   return get<I1, Is...>(get<I0>(static_cast<T&&>(t)));
 }

+}
+
+#include <cute/algorithm/tuple_algorithms.hpp>  // cute::transform
+
+namespace cute {
+
 //
 // rank
 //
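
The reorder matters because tuple_algorithms.hpp itself uses get, so the overloads above must be visible before that include; the namespace is simply closed and reopened around it. The recursive get<I0, I1, ...> idiom being kept ahead of the include can be sketched standalone (std::tuple stands in for cute::tuple; get_path is a hypothetical name for illustration):

#include <cstddef>
#include <tuple>

// get_path<0,1>(t) == std::get<1>(std::get<0>(t)), mirroring CuTe's
// get<I0, I1, Is...> recursion.
template <std::size_t I0, std::size_t... Is, class T>
constexpr decltype(auto) get_path(T&& t) {
  if constexpr (sizeof...(Is) == 0) {
    return std::get<I0>(static_cast<T&&>(t));
  } else {
    return get_path<Is...>(std::get<I0>(static_cast<T&&>(t)));
  }
}

int main() {
  auto t = std::make_tuple(std::make_tuple(1, 2), 3);
  return get_path<0, 1>(t) == 2 ? 0 : 1;
}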
include/cute/numeric/integral_constant.hpp (1 addition, 1 deletion)

@@ -507,7 +507,7 @@ constexpr uint64_t parse_int_digits(uint64_t result, int digit, Ts... digits)
 // var has type cute::constant<int,32>.
 //
 template <char... digits>
-constexpr cute::constant<int,detail::parse_int_digits(0, (digits - '0')...)> operator "" _c()
+constexpr cute::constant<int,detail::parse_int_digits(0, (digits - '0')...)> operator ""_c()
 {
   static_assert((('0' <= digits && digits <= '9') && ...),
                 "Expected 0 <= digit <= 9 for each digit of the integer.");
include/cutlass/bfloat16.h (4 additions, 2 deletions)

@@ -199,11 +199,13 @@ struct alignas(2) bfloat16_t {
     return (float(*this) != 0.0f);
   }

+#if !defined(CUTLASS_ENABLE_SYCL)
   /// Bitcasts to CUDA's bf16 type
   CUTLASS_DEVICE
   __nv_bfloat16 to_nv_bfloat16() const {
     return reinterpret_cast<__nv_bfloat16 const &>(storage);
   }
+#endif

   /// Obtains raw bits
   CUTLASS_HOST_DEVICE
@@ -676,12 +678,12 @@ bfloat16_t operator--(bfloat16_t & lhs, int) {
 //

 CUTLASS_HOST_DEVICE
-cutlass::bfloat16_t operator "" _bf16(long double x) {
+cutlass::bfloat16_t operator ""_bf16(long double x) {
   return cutlass::bfloat16_t(float(x));
 }

 CUTLASS_HOST_DEVICE
-cutlass::bfloat16_t operator "" _bf16(unsigned long long int x) {
+cutlass::bfloat16_t operator ""_bf16(unsigned long long int x) {
   return cutlass::bfloat16_t(int(x));
 }
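
The new guard drops the __nv_bfloat16 bitcast when building for SYCL, where CUDA's cuda_bf16.h types are unavailable; the literal operators get the same C++23 respelling as _c above. A usage sketch for the literals (the diff's qualified return types suggest they are declared at namespace scope in this header):

#include <cutlass/bfloat16.h>

int main() {
  cutlass::bfloat16_t x = 1.5_bf16;  // from a floating-point literal
  cutlass::bfloat16_t y = 2_bf16;    // from an integer literal
  return (float(x) == 1.5f && float(y) == 2.0f) ? 0 : 1;
}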

include/cutlass/cuda_host_adapter.hpp (1 addition, 1 deletion)

@@ -85,7 +85,7 @@ namespace cutlass {
 /////////////////////////////////////////////////////////////////////////////////////////////////


-#if !defined(__CUDACC_RTC__)
+#if !defined(__CUDACC_RTC__) && !defined(CUTLASS_ENABLE_SYCL)

 #include <cudaTypedefs.h>
 #include <driver_types.h>
(file name not shown in this view)

@@ -201,7 +201,7 @@ class Epilogue<
 #if CUDA_BARRIER_ENABLED
   auto synchronize = [] () { cutlass::arch::NamedBarrier::sync(typename TiledCopyS2R::TiledNumThr{}, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
 #else
-  auto synchronize = [] () { __syncthreads(); };
+  auto synchronize = [] () { syncthreads(); };
 #endif

   // Separate out problem shape for convenience
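
Dropping the __ prefix swaps the CUDA intrinsic for an unqualified name that each backend can provide (the same pattern as the syncwarp() and shfl_*_sync renames later in this commit). A sketch of what such a portable wrapper could look like; the name matches the call site, but the SYCL body and the syclcompat::wg_barrier helper are assumptions, not the port's actual implementation:

// Hypothetical portable barrier with the name the epilogue now calls.
CUTLASS_DEVICE
void syncthreads() {
#if defined(CUTLASS_ENABLE_SYCL)
  syclcompat::wg_barrier();   // assumed SYCL work-group barrier helper
#else
  __syncthreads();            // CUDA thread-block barrier
#endif
}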
(file name not shown in this view)

@@ -1026,7 +1026,7 @@ class CollectiveEpilogue<
       // Bringing tensormaps from params to smem for modification later
       copy(recast<uint128_t>(pC_tensormap), recast<uint128_t>(sC_tensormap));
     }
-    __syncwarp();
+    syncwarp();
     return cute::make_tuple(&gmem_tensormap(sm_idx, C_tensormap_index));

   }
@@ -1041,7 +1041,7 @@ class CollectiveEpilogue<
       // Bringing tensormaps from params to smem for modification later
       copy(recast<uint128_t>(pD_tensormap), recast<uint128_t>(sD_tensormap));
     }
-    __syncwarp();
+    syncwarp();
     return cute::make_tuple(&gmem_tensormap(sm_idx, warp_group_idx));
   }
 }
(file name not shown in this view)

@@ -94,7 +94,7 @@ class CollectiveEpilogue<
     SmemLayoutAtomD_,
     CopyOpR2S_,
     CopyAtomC_,
-    CopyOpR2R_,
+    CopyOpR2R_
   > {
 public:
   //
(file name not shown in this view)

@@ -511,12 +511,16 @@ struct Sm90TreeVisitor<
       }
       if constexpr (cute::is_same_v<ElementCompute, float>) {
         uint32_t aux;
+#if defined(__SYCL_CUDA_ARCH__) || defined(__CUDA_ARCH__)
         asm volatile("set.equ.u32.f32 %0, %1, %2;\n" : "=r"(aux) : "f"(frg_compute[i]), "f"(pre_relu)); // NaN outputs 1 in Aux
+#endif
         frg_aux[i] = static_cast<bool>(aux);
       } else if constexpr (cute::is_same_v<ElementCompute, cutlass::half_t>) {
         uint32_t aux;
         cutlass::half_t compute = frg_compute[i];
+#if defined(__SYCL_CUDA_ARCH__) || defined(__CUDA_ARCH__)
         asm volatile("set.equ.u32.f16 %0, %1, %2;\n" : "=r"(aux) : "h"(compute.raw()), "h"(pre_relu.raw())); // NaN outputs 1 in Aux
+#endif
         frg_aux[i] = static_cast<bool>(aux);
       } else {
         frg_aux[i] = frg_compute[i] == pre_relu;
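
The guards keep the PTX out of host and non-CUDA device compilation passes. Note that when the guard is false, aux is read uninitialized on the lines that follow; a fully portable build would want a fallback, sketched here (the fallback branch is an illustration, not part of the commit):

// set.equ is PTX "unordered equal": NaN inputs compare true.
CUTLASS_DEVICE
bool equal_with_nan(float a, float b) {
  uint32_t aux = 0;
#if defined(__SYCL_CUDA_ARCH__) || defined(__CUDA_ARCH__)
  asm volatile("set.equ.u32.f32 %0, %1, %2;\n" : "=r"(aux) : "f"(a), "f"(b));
#else
  aux = (a == b || a != a || b != b) ? 1u : 0u;  // unordered-equal fallback
#endif
  return static_cast<bool>(aux);
}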
include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp (19 additions, 9 deletions)

@@ -69,6 +69,7 @@ namespace detail {
 CUTLASS_DEVICE
 Array<float, 2> top_2_reduce_scalar(Array<float, 2> a, float scalar) {
   Array<float, 2> out;
+#if defined(__CUDA_ARCH__) || defined(__SYCL_CUDA_ARCH__)
   asm volatile(
     "{\n"
     "  .reg .f32 mx;\n"
@@ -78,12 +79,14 @@ Array<float, 2> top_2_reduce_scalar(Array<float, 2> a, float scalar) {
     "  selp.f32 %1, mx, %2, p;\n"
     "  selp.f32 %0, %2, %4, p;\n"
     "}\n" : "=f"(out[0]), "=f"(out[1]) : "f"(a[0]), "f"(a[1]), "f"(scalar));
+#endif
   return out;
 }

 CUTLASS_DEVICE
 Array<float, 2> top_2_reduce(Array<float, 2> a, Array<float, 2> b) {
   Array<float, 2> out;
+#if defined(__CUDA_ARCH__) || defined(__SYCL_CUDA_ARCH__)
   asm volatile(
     "{\n"
     "  .reg .v2 .f32 mx;\n"
@@ -95,12 +98,14 @@ Array<float, 2> top_2_reduce(Array<float, 2> a, Array<float, 2> b) {
     "  selp.f32 %0, %2, %4, p;\n" // a0 > b0 ? a0 : b0
     "}\n" : "=f"(out[0]), "=f"(out[1]) :
     "f"(a[0]), "f"(a[1]), "f"(b[0]), "f"(b[1]));
+#endif
   return out;
 }

 CUTLASS_DEVICE
 Array<float, 4> top_4_reduce_scalar(Array<float, 4> a, float scalar) {
   Array<float, 4> out;
+#if defined(__CUDA_ARCH__) || defined(__SYCL_CUDA_ARCH__)
   asm volatile(
     "{\n"
     "  .reg .f32 mx;\n" // max(a3, b)
@@ -120,12 +125,14 @@ Array<float, 4> top_4_reduce_scalar(Array<float, 4> a, float scalar) {
     "}\n" :
     "=f"(out[0]), "=f"(out[1]), "=f"(out[2]), "=f"(out[3]) :
     "f"(a[0]), "f"(a[1]), "f"(a[2]), "f"(a[3]), "f"(scalar));
+#endif
   return out;
 }

 CUTLASS_DEVICE
 Array<float, 4> top_4_reduce(Array<float, 4> a, Array<float, 4> b) {
   Array<float, 4> out;
+#if defined(__CUDA_ARCH__) || defined(__SYCL_CUDA_ARCH__)
   asm volatile(
     "{\n"
     "  .reg .f32 mxa0b1;\n" // max(a0, b1)
@@ -191,6 +198,7 @@ Array<float, 4> top_4_reduce(Array<float, 4> a, Array<float, 4> b) {
     "=f"(out[0]), "=f"(out[1]), "=f"(out[2]), "=f"(out[3]) :
     "f"(a[0]), "f"(a[1]), "f"(a[2]), "f"(a[3]),
     "f"(b[0]), "f"(b[1]), "f"(b[2]), "f"(b[3]));
+#endif
   return out;
 }

@@ -270,6 +278,7 @@ Element topk_logsumexp(cutlass::Array<Element, N> a) {
 CUTLASS_DEVICE
 float fast_masked_softmax(float value, float minimum, float logsumexp) {
   float new_value;
+#if defined(__CUDA_ARCH__) || defined(__SYCL_CUDA_ARCH__)
   asm volatile(
     "{\n"
     "  .reg .pred p0;\n"
@@ -302,6 +311,7 @@ float fast_masked_softmax(float value, float minimum, float logsumexp) {
     // Mask or softmax
     "  selp.f32 %0, %%f10, 0f00000000, p0;\n"
     "}\n" : "=f"(new_value) : "f"(value), "f"(minimum), "f"(logsumexp));
+#endif
   return new_value;
 }

@@ -375,7 +385,7 @@ struct Sm90TopKSoftmaxColReduction {
   void shuffle_up_sync(uint32_t delta, int lane_id) {
     static_assert(sizeof(ReductionResult) == sizeof(uint64_t));
     uint64_t r = reinterpret_cast<uint64_t&>(*this);
-    r = __shfl_up_sync(0xFFFFFFFF, r, delta);
+    r = shfl_up_sync(0xFFFFFFFF, r, delta);
     *this = (lane_id - static_cast<int>(delta) >= 0) ? reinterpret_cast<ReductionResult&>(r) : *this;
   }
 };
@@ -402,7 +412,7 @@ struct Sm90TopKSoftmaxColReduction {
     if constexpr (TopK == 2) {
       static_assert(sizeof(TopKResult) == sizeof(uint64_t));
       uint64_t top_k = reinterpret_cast<uint64_t&>(*this);
-      top_k = __shfl_xor_sync(0xFFFFFFFF, top_k, laneMask);
+      top_k = shfl_xor_sync(0xFFFFFFFF, top_k, laneMask);
       auto synced_v = reinterpret_cast<TopKResult&>(top_k);
       detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_);
     }
@@ -412,16 +422,16 @@ struct Sm90TopKSoftmaxColReduction {
       uint64_t top_k_arr[2];
       top_k_arr[0] = top_k_ptr[0];
       top_k_arr[1] = top_k_ptr[1];
-      top_k_arr[0] = __shfl_xor_sync(0xFFFFFFFF, top_k_arr[0], laneMask);
-      top_k_arr[1] = __shfl_xor_sync(0xFFFFFFFF, top_k_arr[1], laneMask);
+      top_k_arr[0] = shfl_xor_sync(0xFFFFFFFF, top_k_arr[0], laneMask);
+      top_k_arr[1] = shfl_xor_sync(0xFFFFFFFF, top_k_arr[1], laneMask);
       auto synced_v = reinterpret_cast<TopKResult&>(top_k_arr);
       detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_);
     }
     else {
       TopKResult synced_v;
       CUTLASS_PRAGMA_UNROLL
       for (int i = 0; i < TopK; ++i) {
-        synced_v.top_k_[i] = __shfl_xor_sync(0xFFFFFFFF, top_k_[i], laneMask);
+        synced_v.top_k_[i] = shfl_xor_sync(0xFFFFFFFF, top_k_[i], laneMask);
       }
       detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_);
     }
@@ -433,7 +443,7 @@ struct Sm90TopKSoftmaxColReduction {
     if constexpr (TopK == 2) {
       static_assert(sizeof(TopKResult) == sizeof(uint64_t));
       uint64_t top_k = reinterpret_cast<uint64_t&>(*this);
-      top_k = __shfl_down_sync(0xFFFFFFFF, top_k, delta);
+      top_k = shfl_down_sync(0xFFFFFFFF, top_k, delta);
       auto synced_v = reinterpret_cast<TopKResult&>(top_k);
       detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_);
     }
@@ -443,16 +453,16 @@ struct Sm90TopKSoftmaxColReduction {
      uint64_t top_k_arr[2];
       top_k_arr[0] = top_k_ptr[0];
       top_k_arr[1] = top_k_ptr[1];
-      top_k_arr[0] = __shfl_down_sync(0xFFFFFFFF, top_k_arr[0], delta);
-      top_k_arr[1] = __shfl_down_sync(0xFFFFFFFF, top_k_arr[1], delta);
+      top_k_arr[0] = shfl_down_sync(0xFFFFFFFF, top_k_arr[0], delta);
+      top_k_arr[1] = shfl_down_sync(0xFFFFFFFF, top_k_arr[1], delta);
       auto synced_v = reinterpret_cast<TopKResult&>(top_k_arr);
       detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_);
     }
     else {
       TopKResult synced_v;
       CUTLASS_PRAGMA_UNROLL
       for (int i = 0; i < TopK; ++i) {
-        synced_v.top_k_[i] = __shfl_down_sync(0xFFFFFFFF, top_k_[i], delta);
+        synced_v.top_k_[i] = shfl_down_sync(0xFFFFFFFF, top_k_[i], delta);
       }
       detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_);
     }
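
As with syncthreads(), the unprefixed shfl_*_sync names let a SYCL backend supply sub-group equivalents of the CUDA warp shuffles. A sketch of such a shim; the SYCL branch uses SYCL 2020 group algorithms, but the exact helpers (this_work_item::get_sub_group, permute_group_by_xor) are assumptions about the port, not its actual code:

template <class T>
CUTLASS_DEVICE T shfl_xor_sync(unsigned mask, T var, int laneMask) {
#if defined(CUTLASS_ENABLE_SYCL)
  auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
  return sycl::permute_group_by_xor(sg, var, laneMask);  // XOR-lane shuffle
#else
  return __shfl_xor_sync(mask, var, laneMask);           // CUDA warp shuffle
#endif
}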
include/cutlass/float8.h (4 additions, 4 deletions)

@@ -1274,22 +1274,22 @@ struct numeric_limits<cutlass::float_e5m2_t> :
 //

 CUTLASS_HOST_DEVICE
-cutlass::float_e4m3_t operator "" _fe4m3(long double x) {
+cutlass::float_e4m3_t operator ""_fe4m3(long double x) {
   return cutlass::float_e4m3_t(float(x));
 }

 CUTLASS_HOST_DEVICE
-cutlass::float_e4m3_t operator "" _fe4m3(unsigned long long int x) {
+cutlass::float_e4m3_t operator ""_fe4m3(unsigned long long int x) {
   return cutlass::float_e4m3_t(int(x));
 }

 CUTLASS_HOST_DEVICE
-cutlass::float_e5m2_t operator "" _fe5m2(long double x) {
+cutlass::float_e5m2_t operator ""_fe5m2(long double x) {
   return cutlass::float_e5m2_t(float(x));
 }

 CUTLASS_HOST_DEVICE
-cutlass::float_e5m2_t operator "" _fe5m2(unsigned long long int x) {
+cutlass::float_e5m2_t operator ""_fe5m2(unsigned long long int x) {
   return cutlass::float_e5m2_t(int(x));
 }

(file name not shown in this view)

@@ -643,7 +643,7 @@ struct CollectiveMma<
       copy(recast<uint128_t>(pA_tensormap), recast<uint128_t>(sA_tensormap));
       copy(recast<uint128_t>(pB_tensormap), recast<uint128_t>(sB_tensormap));
     }
-    __syncwarp();
+    syncwarp();

     return cute::make_tuple(tma_desc_a, tma_desc_b);
   }
(The remaining changed files are not shown in this view.)
