From a11985523625cacc248dbceccbc97eaeee03b3f1 Mon Sep 17 00:00:00 2001 From: Alejandro Acosta Date: Fri, 29 Nov 2024 11:46:04 +0000 Subject: [PATCH] Solve issues --- examples/cute/tutorial/tiled_copy_sycl.cpp | 4 +-- include/cute/arch/util.hpp | 25 ++++++----------- include/cute/atom/mma_traits.hpp | 2 +- include/cute/int_tuple.hpp | 11 +++++--- include/cute/numeric/integral_constant.hpp | 2 +- include/cutlass/bfloat16.h | 6 ++-- include/cutlass/cuda_host_adapter.hpp | 2 +- .../sm70_epilogue_vectorized_array.hpp | 2 +- ...m90_epilogue_array_tma_warpspecialized.hpp | 4 +-- .../sm90_epilogue_tma_warpspecialized.hpp | 2 +- ...90_visitor_compute_tma_warpspecialized.hpp | 4 +++ .../fusion/sm90_visitor_topk_softmax.hpp | 28 +++++++++++++------ include/cutlass/float8.h | 8 +++--- ..._mma_array_tma_gmma_ss_warpspecialized.hpp | 2 +- include/cutlass/gpu_generics.h | 16 +++++++++++ include/cutlass/half.h | 4 +-- include/cutlass/kernel_launch.h | 4 +++ include/cutlass/numeric_conversion.h | 7 ++++- include/cutlass/tfloat32.h | 4 +-- 19 files changed, 87 insertions(+), 50 deletions(-) diff --git a/examples/cute/tutorial/tiled_copy_sycl.cpp b/examples/cute/tutorial/tiled_copy_sycl.cpp index 3093e0351b..77c5484832 100644 --- a/examples/cute/tutorial/tiled_copy_sycl.cpp +++ b/examples/cute/tutorial/tiled_copy_sycl.cpp @@ -192,8 +192,8 @@ int main(int argc, char** argv) return -1; } // Equivalent check to the above - if (not weakly_compatible(block_shape, tensor_shape)) { - std::cerr << "Expected the tensors to be weakly compatible with the block_shape." << std::endl; + if (not evenly_divides(tensor_shape, block_shape)) { + std::cerr << "Expected the block_shape to evenly divide the tensor shape." << std::endl; return -1; } diff --git a/include/cute/arch/util.hpp b/include/cute/arch/util.hpp index f4bfa721f5..9290439b9f 100644 --- a/include/cute/arch/util.hpp +++ b/include/cute/arch/util.hpp @@ -257,8 +257,7 @@ explode(Fn fn, return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]..., e[Ie]..., f[If]...); } -#if defined(CUTLASS_ENABLE_SYCL) -template , PtrG&& g, int_sequence) { - return MMA_Op::fma(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]..., e[Ie]..., f[If]..., g[Ig]...); + return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]..., e[Ie]..., f[If]..., g[Ig]...); } -#else -template + class PtrC, int... Ic> CUTE_HOST_DEVICE constexpr void -explode(Fn fn, - PtrD&& d, int_sequence, +explode_mma(PtrD&& d, int_sequence, PtrA&& a, int_sequence, PtrB&& b, int_sequence, - PtrC&& c, int_sequence, - PtrE&& e, int_sequence, - PtrF&& f, int_sequence, - PtrG&& g, int_sequence) + PtrC&& c, int_sequence) { - return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]..., e[Ie]..., f[If]..., g[Ig]...); + return MMA_Op::fma(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]...); } #endif diff --git a/include/cute/atom/mma_traits.hpp b/include/cute/atom/mma_traits.hpp index 06c4593c11..b5569fc2f5 100644 --- a/include/cute/atom/mma_traits.hpp +++ b/include/cute/atom/mma_traits.hpp @@ -144,7 +144,7 @@ mma_unpack(AnyMMATraits const& traits, CUTE_STATIC_ASSERT_V(size(rC) == Int{}); #if defined(CUTLASS_ENABLE_SYCL) - detail::explode(rD, make_int_sequence{}, + detail::explode_mma(rD, make_int_sequence{}, rA, make_int_sequence{}, rB, make_int_sequence{}, rC, make_int_sequence{}); diff --git a/include/cute/int_tuple.hpp b/include/cute/int_tuple.hpp index 95d06bbdd7..132e103830 100644 --- a/include/cute/int_tuple.hpp +++ b/include/cute/int_tuple.hpp @@ -34,16 +34,13 @@ #include // cute::array #include // cute::is_tuple #include // cute::Int -#include // cute::transform /** IntTuple is an integer or a tuple of IntTuples. * This file holds utilities for working with IntTuples, * but does not hold a concrete concept or class of IntTuple. */ -namespace cute -{ - +namespace cute { // Implementation of get<0>(Integral). // Even though is_tuple is false and tuple_size doesn't compile, // CuTe defines rank(Integral) as 1, so it's useful for get<0>(Integral) to return its input @@ -65,6 +62,12 @@ get(T&& t) noexcept return get(get(static_cast(t))); } +} + +#include // cute::transform + +namespace cute { + // // rank // diff --git a/include/cute/numeric/integral_constant.hpp b/include/cute/numeric/integral_constant.hpp index 3a8d036eef..88b00922f7 100644 --- a/include/cute/numeric/integral_constant.hpp +++ b/include/cute/numeric/integral_constant.hpp @@ -507,7 +507,7 @@ constexpr uint64_t parse_int_digits(uint64_t result, int digit, Ts... digits) // var has type cute::constant. // template -constexpr cute::constant operator "" _c() +constexpr cute::constant operator ""_c() { static_assert((('0' <= digits && digits <= '9') && ...), "Expected 0 <= digit <= 9 for each digit of the integer."); diff --git a/include/cutlass/bfloat16.h b/include/cutlass/bfloat16.h index a4c3463270..87b1f2f42f 100644 --- a/include/cutlass/bfloat16.h +++ b/include/cutlass/bfloat16.h @@ -199,11 +199,13 @@ struct alignas(2) bfloat16_t { return (float(*this) != 0.0f); } +#if !defined(CUTLASS_ENABLE_SYCL) /// Bitcasts to CUDA's bf16 type CUTLASS_DEVICE __nv_bfloat16 to_nv_bfloat16() const { return reinterpret_cast<__nv_bfloat16 const &>(storage); } +#endif /// Obtains raw bits CUTLASS_HOST_DEVICE @@ -676,12 +678,12 @@ bfloat16_t operator--(bfloat16_t & lhs, int) { // CUTLASS_HOST_DEVICE -cutlass::bfloat16_t operator "" _bf16(long double x) { +cutlass::bfloat16_t operator ""_bf16(long double x) { return cutlass::bfloat16_t(float(x)); } CUTLASS_HOST_DEVICE -cutlass::bfloat16_t operator "" _bf16(unsigned long long int x) { +cutlass::bfloat16_t operator ""_bf16(unsigned long long int x) { return cutlass::bfloat16_t(int(x)); } diff --git a/include/cutlass/cuda_host_adapter.hpp b/include/cutlass/cuda_host_adapter.hpp index 3b37b1fab8..2c5f61d6ed 100644 --- a/include/cutlass/cuda_host_adapter.hpp +++ b/include/cutlass/cuda_host_adapter.hpp @@ -85,7 +85,7 @@ namespace cutlass { ///////////////////////////////////////////////////////////////////////////////////////////////// -#if !defined(__CUDACC_RTC__) +#if !defined(__CUDACC_RTC__) && !defined(CUTLASS_ENABLE_SYCL) #include #include diff --git a/include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp b/include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp index 8a70370b21..5583f96328 100644 --- a/include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp +++ b/include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp @@ -201,7 +201,7 @@ class Epilogue< #if CUDA_BARRIER_ENABLED auto synchronize = [] () { cutlass::arch::NamedBarrier::sync(typename TiledCopyS2R::TiledNumThr{}, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; #else - auto synchronize = [] () { __syncthreads(); }; + auto synchronize = [] () { syncthreads(); }; #endif // Separate out problem shape for convenience diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp index 84b6e14eeb..ae095cf915 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp @@ -1026,7 +1026,7 @@ class CollectiveEpilogue< // Bringing tensormaps from params to smem for modification later copy(recast(pC_tensormap), recast(sC_tensormap)); } - __syncwarp(); + syncwarp(); return cute::make_tuple(&gmem_tensormap(sm_idx, C_tensormap_index)); } @@ -1041,7 +1041,7 @@ class CollectiveEpilogue< // Bringing tensormaps from params to smem for modification later copy(recast(pD_tensormap), recast(sD_tensormap)); } - __syncwarp(); + syncwarp(); return cute::make_tuple(&gmem_tensormap(sm_idx, warp_group_idx)); } } diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp index b96c4aea00..1ecf854085 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp @@ -94,7 +94,7 @@ class CollectiveEpilogue< SmemLayoutAtomD_, CopyOpR2S_, CopyAtomC_, - CopyOpR2R_, + CopyOpR2R_ > { public: // diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp index 9f2e00d18b..49cbb38a90 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp @@ -511,12 +511,16 @@ struct Sm90TreeVisitor< } if constexpr (cute::is_same_v) { uint32_t aux; +#if defined(__SYCL_CUDA_ARCH__) || defined(__CUDA_ARCH__) asm volatile("set.equ.u32.f32 %0, %1, %2;\n" : "=r"(aux) : "f"(frg_compute[i]), "f"(pre_relu)); // NaN outputs 1 in Aux +#endif frg_aux[i] = static_cast(aux); } else if constexpr (cute::is_same_v) { uint32_t aux; cutlass::half_t compute = frg_compute[i]; +#if defined(__SYCL_CUDA_ARCH__) || defined(__CUDA_ARCH__) asm volatile("set.equ.u32.f16 %0, %1, %2;\n" : "=r"(aux) : "h"(compute.raw()), "h"(pre_relu.raw())); // NaN outputs 1 in Aux +#endif frg_aux[i] = static_cast(aux); } else { frg_aux[i] = frg_compute[i] == pre_relu; diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp index 53c0dce8ba..4624974432 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp @@ -69,6 +69,7 @@ namespace detail { CUTLASS_DEVICE Array top_2_reduce_scalar(Array a, float scalar) { Array out; +#if defined(__CUDA_ARCH__) || defined(__SYCL_CUDA_ARCH__) asm volatile( "{\n" " .reg .f32 mx;\n" @@ -78,12 +79,14 @@ Array top_2_reduce_scalar(Array a, float scalar) { " selp.f32 %1, mx, %2, p;\n" " selp.f32 %0, %2, %4, p;\n" "}\n" : "=f"(out[0]), "=f"(out[1]) : "f"(a[0]), "f"(a[1]), "f"(scalar)); +#endif return out; } CUTLASS_DEVICE Array top_2_reduce(Array a, Array b) { Array out; +#if defined(__CUDA_ARCH__) || defined(__SYCL_CUDA_ARCH__) asm volatile( "{\n" " .reg .v2 .f32 mx;\n" @@ -95,12 +98,14 @@ Array top_2_reduce(Array a, Array b) { " selp.f32 %0, %2, %4, p;\n" // a0 > b0 ? a0 : b0 "}\n" : "=f"(out[0]), "=f"(out[1]) : "f"(a[0]), "f"(a[1]), "f"(b[0]), "f"(b[1])); +#endif return out; } CUTLASS_DEVICE Array top_4_reduce_scalar(Array a, float scalar) { Array out; +#if defined(__CUDA_ARCH__) || defined(__SYCL_CUDA_ARCH__) asm volatile( "{\n" " .reg .f32 mx;\n" // max(a3, b) @@ -120,12 +125,14 @@ Array top_4_reduce_scalar(Array a, float scalar) { "}\n" : "=f"(out[0]), "=f"(out[1]), "=f"(out[2]), "=f"(out[3]) : "f"(a[0]), "f"(a[1]), "f"(a[2]), "f"(a[3]), "f"(scalar)); +#endif return out; } CUTLASS_DEVICE Array top_4_reduce(Array a, Array b) { Array out; +#if defined(__CUDA_ARCH__) || defined(__SYCL_CUDA_ARCH__) asm volatile( "{\n" " .reg .f32 mxa0b1;\n" // max(a0, b1) @@ -191,6 +198,7 @@ Array top_4_reduce(Array a, Array b) { "=f"(out[0]), "=f"(out[1]), "=f"(out[2]), "=f"(out[3]) : "f"(a[0]), "f"(a[1]), "f"(a[2]), "f"(a[3]), "f"(b[0]), "f"(b[1]), "f"(b[2]), "f"(b[3])); +#endif return out; } @@ -270,6 +278,7 @@ Element topk_logsumexp(cutlass::Array a) { CUTLASS_DEVICE float fast_masked_softmax(float value, float minimum, float logsumexp) { float new_value; +#if defined(__CUDA_ARCH__) || defined(__SYCL_CUDA_ARCH__) asm volatile( "{\n" " .reg .pred p0;\n" @@ -302,6 +311,7 @@ float fast_masked_softmax(float value, float minimum, float logsumexp) { // Mask or softmax " selp.f32 %0, %%f10, 0f00000000, p0;\n" "}\n" : "=f"(new_value) : "f"(value), "f"(minimum), "f"(logsumexp)); +#endif return new_value; } @@ -375,7 +385,7 @@ struct Sm90TopKSoftmaxColReduction { void shuffle_up_sync(uint32_t delta, int lane_id) { static_assert(sizeof(ReductionResult) == sizeof(uint64_t)); uint64_t r = reinterpret_cast(*this); - r = __shfl_up_sync(0xFFFFFFFF, r, delta); + r = shfl_up_sync(0xFFFFFFFF, r, delta); *this = (lane_id - static_cast(delta) >= 0) ? reinterpret_cast(r) : *this; } }; @@ -402,7 +412,7 @@ struct Sm90TopKSoftmaxColReduction { if constexpr (TopK == 2) { static_assert(sizeof(TopKResult) == sizeof(uint64_t)); uint64_t top_k = reinterpret_cast(*this); - top_k = __shfl_xor_sync(0xFFFFFFFF, top_k, laneMask); + top_k = shfl_xor_sync(0xFFFFFFFF, top_k, laneMask); auto synced_v = reinterpret_cast(top_k); detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); } @@ -412,8 +422,8 @@ struct Sm90TopKSoftmaxColReduction { uint64_t top_k_arr[2]; top_k_arr[0] = top_k_ptr[0]; top_k_arr[1] = top_k_ptr[1]; - top_k_arr[0] = __shfl_xor_sync(0xFFFFFFFF, top_k_arr[0], laneMask); - top_k_arr[1] = __shfl_xor_sync(0xFFFFFFFF, top_k_arr[1], laneMask); + top_k_arr[0] = shfl_xor_sync(0xFFFFFFFF, top_k_arr[0], laneMask); + top_k_arr[1] = shfl_xor_sync(0xFFFFFFFF, top_k_arr[1], laneMask); auto synced_v = reinterpret_cast(top_k_arr); detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); } @@ -421,7 +431,7 @@ struct Sm90TopKSoftmaxColReduction { TopKResult synced_v; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < TopK; ++i) { - synced_v.top_k_[i] = __shfl_xor_sync(0xFFFFFFFF, top_k_[i], laneMask); + synced_v.top_k_[i] = shfl_xor_sync(0xFFFFFFFF, top_k_[i], laneMask); } detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); } @@ -433,7 +443,7 @@ struct Sm90TopKSoftmaxColReduction { if constexpr (TopK == 2) { static_assert(sizeof(TopKResult) == sizeof(uint64_t)); uint64_t top_k = reinterpret_cast(*this); - top_k = __shfl_down_sync(0xFFFFFFFF, top_k, delta); + top_k = shfl_down_sync(0xFFFFFFFF, top_k, delta); auto synced_v = reinterpret_cast(top_k); detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); } @@ -443,8 +453,8 @@ struct Sm90TopKSoftmaxColReduction { uint64_t top_k_arr[2]; top_k_arr[0] = top_k_ptr[0]; top_k_arr[1] = top_k_ptr[1]; - top_k_arr[0] = __shfl_down_sync(0xFFFFFFFF, top_k_arr[0], delta); - top_k_arr[1] = __shfl_down_sync(0xFFFFFFFF, top_k_arr[1], delta); + top_k_arr[0] = shfl_down_sync(0xFFFFFFFF, top_k_arr[0], delta); + top_k_arr[1] = shfl_down_sync(0xFFFFFFFF, top_k_arr[1], delta); auto synced_v = reinterpret_cast(top_k_arr); detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); } @@ -452,7 +462,7 @@ struct Sm90TopKSoftmaxColReduction { TopKResult synced_v; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < TopK; ++i) { - synced_v.top_k_[i] = __shfl_down_sync(0xFFFFFFFF, top_k_[i], delta); + synced_v.top_k_[i] = shfl_down_sync(0xFFFFFFFF, top_k_[i], delta); } detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); } diff --git a/include/cutlass/float8.h b/include/cutlass/float8.h index 8efee8c261..cc37aadab1 100644 --- a/include/cutlass/float8.h +++ b/include/cutlass/float8.h @@ -1274,22 +1274,22 @@ struct numeric_limits : // CUTLASS_HOST_DEVICE -cutlass::float_e4m3_t operator "" _fe4m3(long double x) { +cutlass::float_e4m3_t operator ""_fe4m3(long double x) { return cutlass::float_e4m3_t(float(x)); } CUTLASS_HOST_DEVICE -cutlass::float_e4m3_t operator "" _fe4m3(unsigned long long int x) { +cutlass::float_e4m3_t operator ""_fe4m3(unsigned long long int x) { return cutlass::float_e4m3_t(int(x)); } CUTLASS_HOST_DEVICE -cutlass::float_e5m2_t operator "" _fe5m2(long double x) { +cutlass::float_e5m2_t operator ""_fe5m2(long double x) { return cutlass::float_e5m2_t(float(x)); } CUTLASS_HOST_DEVICE -cutlass::float_e5m2_t operator "" _fe5m2(unsigned long long int x) { +cutlass::float_e5m2_t operator ""_fe5m2(unsigned long long int x) { return cutlass::float_e5m2_t(int(x)); } diff --git a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp index 9e56f96704..80e128c3a0 100644 --- a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp @@ -643,7 +643,7 @@ struct CollectiveMma< copy(recast(pA_tensormap), recast(sA_tensormap)); copy(recast(pB_tensormap), recast(sB_tensormap)); } - __syncwarp(); + syncwarp(); return cute::make_tuple(tma_desc_a, tma_desc_b); } diff --git a/include/cutlass/gpu_generics.h b/include/cutlass/gpu_generics.h index 44b5a92acb..c476a2fc67 100644 --- a/include/cutlass/gpu_generics.h +++ b/include/cutlass/gpu_generics.h @@ -295,6 +295,22 @@ unsigned int shfl_sync( #endif } +CUTLASS_DEVICE +unsigned int shfl_xor_sync( + unsigned int const mask, + unsigned int const var, + int const laneMask, + int const width = NumThreadsPerWarp) { +#if defined(__CUDA_ARCH__) + return __shfl_xor_sync(mask, var, laneMask, width); +#elif defined(__SYCL_DEVICE_ONLY__) + auto g = syclcompat::get_nd_item<1>().get_sub_group(); + return syclcompat::permute_sub_group_by_xor(g, var, laneMask); +#else + return 0; +#endif +} + //////////////////////////////////////////////////////////////////////////////////////////////////// /* diff --git a/include/cutlass/half.h b/include/cutlass/half.h index 20e3c70b01..10a80f04ec 100644 --- a/include/cutlass/half.h +++ b/include/cutlass/half.h @@ -928,12 +928,12 @@ half_t operator--(half_t & lhs, int) { // CUTLASS_HOST_DEVICE -cutlass::half_t operator "" _hf(long double x) { +cutlass::half_t operator ""_hf(long double x) { return cutlass::half_t(float(x)); } CUTLASS_HOST_DEVICE -cutlass::half_t operator "" _hf(unsigned long long int x) { +cutlass::half_t operator ""_hf(unsigned long long int x) { return cutlass::half_t(int(x)); } diff --git a/include/cutlass/kernel_launch.h b/include/cutlass/kernel_launch.h index ca3380a2a1..77c8406c48 100644 --- a/include/cutlass/kernel_launch.h +++ b/include/cutlass/kernel_launch.h @@ -34,7 +34,9 @@ #pragma once +#if !defined(CUTLASS_ENABLE_SYCL) #include +#endif #include "cutlass/cutlass.h" #include "cutlass/trace.h" @@ -87,7 +89,9 @@ Status kernel_launch( #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) CUTLASS_TRACE_HOST("cutlass::kernel_launch: No PDL"); #endif +#if !defined(CUTLASS_ENABLE_SYCL) device_kernel<<>>(kernel_params); +#endif } else { #if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))) diff --git a/include/cutlass/numeric_conversion.h b/include/cutlass/numeric_conversion.h index e13a9d0cb9..5b9fcf57b0 100644 --- a/include/cutlass/numeric_conversion.h +++ b/include/cutlass/numeric_conversion.h @@ -3638,7 +3638,8 @@ struct NumericArrayConverter { "Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private convert dispatch."); PackedResultType r; - #if defined __CUDA_ARCH__ && __CUDA_ARCH__ <= 800 + #if (defined __CUDA_ARCH__ && __CUDA_ARCH__ <= 800) || \ + (defined __SYCL_CUDA_ARCH__ && __SYCL_CUDA_ARCH__ <= 800) // View the input as reg uint32_t src_reg = to_reg(source); static constexpr int fp32_base = 0x4B400000; @@ -3663,7 +3664,11 @@ struct NumericArrayConverter { CUTLASS_PRAGMA_UNROLL for (int ii = 0; ii < PackedResultType::kElements; ++ii) { +#if defined(__CUDA_ARCH__) t[ii] = __dp4a(x, mask[ii], 0); +#else + t[ii] = x * mask[ii]; +#endif r[ii] = static_cast(t[ii]); } #endif diff --git a/include/cutlass/tfloat32.h b/include/cutlass/tfloat32.h index 259a4ba180..4ff36f1194 100644 --- a/include/cutlass/tfloat32.h +++ b/include/cutlass/tfloat32.h @@ -466,12 +466,12 @@ tfloat32_t operator--(tfloat32_t & lhs, int) { // CUTLASS_HOST_DEVICE -cutlass::tfloat32_t operator "" _tf32(long double x) { +cutlass::tfloat32_t operator ""_tf32(long double x) { return cutlass::tfloat32_t(float(x)); } CUTLASS_HOST_DEVICE -cutlass::tfloat32_t operator "" _tf32(unsigned long long int x) { +cutlass::tfloat32_t operator ""_tf32(unsigned long long int x) { return cutlass::tfloat32_t(int(x)); }