Skip to content

Commit 39c1a87

Browse files
committed
w4a16 for sm75
1 parent 44782a1 commit 39c1a87

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

src/turbomind/kernels/gemm_s_f16/cta_iterator.h

+19
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#pragma once
44

55
#include "common.h"
6+
#include <cstddef>
67
#include <cstdint>
78

89
namespace turbomind {
@@ -236,7 +237,13 @@ struct IteratorA {
236237

237238
__device__ void prefetch(bool mask)
238239
{
240+
#if TURBOMIND_ARCH_SM80
239241
cp_async_cg_A(smem_int_ptr_ + dst_offset_, (const AccessType*)src_ + src_offset_, mask);
242+
#else
243+
if (mask) {
244+
*(AccessType*)((uint8_t*)smem_ + dst_offset_) = __ldg((const AccessType*)src_ + src_offset_);
245+
}
246+
#endif
240247
}
241248
};
242249

@@ -417,7 +424,13 @@ struct IteratorQ {
417424

418425
__device__ void prefetch(bool mask)
419426
{
427+
#if TURBOMIND_ARCH_SM80
420428
cp_async_ca(smem_int_ptr_ + dst_offset_, (const AccessType*)src_ + src_offset_, mask);
429+
#else
430+
if (mask) {
431+
*(AccessType*)((uint8_t*)smem_ + dst_offset_) = __ldg((const AccessType*)src_ + src_offset_);
432+
}
433+
#endif
421434
}
422435
};
423436

@@ -613,8 +626,14 @@ struct IteratorB {
613626

614627
__device__ void prefetch(bool mask)
615628
{
629+
#if TURBOMIND_ARCH_SM80
616630
cp_async_cg_B(
617631
smem_int_ptr_ + tmp_dst_offset_, (const AccessType*)(src_ + tmp_src_offset_), is_valid_n_ && mask);
632+
#else
633+
if (is_valid_n_ && mask) {
634+
*(AccessType*)((uint8_t*)smem_ + tmp_dst_offset_) = __ldg((const AccessType*)(src_ + tmp_src_offset_));
635+
}
636+
#endif
618637
}
619638
};
620639

src/turbomind/kernels/gemm_s_f16/gemm_template.h

+21-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,23 @@
99

1010
namespace turbomind {
1111

12+
__inline__ __device__ void
13+
mma_m16n8k8_row_col(Array<float, 4>& d, const Array<half, 4>& a, const Array<half, 2>& b, Array<float, 4>& c)
14+
{
15+
#if TURBOMIND_ARCH_SM75
16+
uint32_t const* A = reinterpret_cast<uint32_t const*>(&a);
17+
uint32_t const* B = reinterpret_cast<uint32_t const*>(&b);
18+
float const* C = reinterpret_cast<float const*>(&c);
19+
float* D = reinterpret_cast<float*>(&d);
20+
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, "
21+
"{%4,%5}, {%6}, {%7,%8,%9,%10};\n"
22+
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
23+
: "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
24+
#else
25+
assert(TURBOMIND_ARCH_SM75);
26+
#endif
27+
}
28+
1229
__inline__ __device__ void
1330
mma_m16n8k16_row_col(Array<float, 4>& d, const Array<half, 8>& a, const Array<half, 4>& b, Array<float, 4>& c)
1431
{
@@ -22,7 +39,10 @@ mma_m16n8k16_row_col(Array<float, 4>& d, const Array<half, 8>& a, const Array<ha
2239
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
2340
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
2441
#else
25-
assert(TURBOMIND_ARCH_SM80);
42+
const Array<half, 4>* _a = (const Array<half, 4>*)&a;
43+
const Array<half, 2>* _b = (const Array<half, 2>*)&b;
44+
mma_m16n8k8_row_col(d, _a[0], _b[0], c);
45+
mma_m16n8k8_row_col(d, _a[1], _b[1], d);
2646
#endif
2747
}
2848

0 commit comments

Comments
 (0)