w4a16 for sm75
lzhangzz committed Nov 2, 2023
1 parent 44782a1 commit 39c1a87
Showing 2 changed files with 40 additions and 1 deletion.
19 changes: 19 additions & 0 deletions src/turbomind/kernels/gemm_s_f16/cta_iterator.h
@@ -3,6 +3,7 @@
 #pragma once
 
 #include "common.h"
+#include <cstddef>
 #include <cstdint>
 
 namespace turbomind {
@@ -236,7 +237,13 @@ struct IteratorA {
 
     __device__ void prefetch(bool mask)
     {
+#if TURBOMIND_ARCH_SM80
         cp_async_cg_A(smem_int_ptr_ + dst_offset_, (const AccessType*)src_ + src_offset_, mask);
+#else
+        if (mask) {
+            *(AccessType*)((uint8_t*)smem_ + dst_offset_) = __ldg((const AccessType*)src_ + src_offset_);
+        }
+#endif
     }
 };
 
@@ -417,7 +424,13 @@ struct IteratorQ {
 
     __device__ void prefetch(bool mask)
     {
+#if TURBOMIND_ARCH_SM80
         cp_async_ca(smem_int_ptr_ + dst_offset_, (const AccessType*)src_ + src_offset_, mask);
+#else
+        if (mask) {
+            *(AccessType*)((uint8_t*)smem_ + dst_offset_) = __ldg((const AccessType*)src_ + src_offset_);
+        }
+#endif
     }
 };
 
@@ -613,8 +626,14 @@ struct IteratorB {
 
     __device__ void prefetch(bool mask)
     {
+#if TURBOMIND_ARCH_SM80
         cp_async_cg_B(
             smem_int_ptr_ + tmp_dst_offset_, (const AccessType*)(src_ + tmp_src_offset_), is_valid_n_ && mask);
+#else
+        if (is_valid_n_ && mask) {
+            *(AccessType*)((uint8_t*)smem_ + tmp_dst_offset_) = __ldg((const AccessType*)(src_ + tmp_src_offset_));
+        }
+#endif
     }
 };
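For context (not part of the commit): all three iterators apply the same guard, issuing cp.async on SM80+ and falling back to a synchronous __ldg load plus shared-memory store on SM75. A minimal standalone sketch of that pattern, with uint4 standing in for the 16-byte AccessType and the hypothetical name prefetch_sketch; the committed cp_async_* helpers fold the predicate into the PTX, which is simplified to a plain branch here:

    #include <cstdint>

    __device__ void prefetch_sketch(uint4* smem_dst, const uint4* gmem_src, bool mask)
    {
    #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
        // Asynchronous 16-byte global->shared copy; the caller must later issue
        // cp.async.commit_group / cp.async.wait_group before reading smem_dst.
        if (mask) {
            const uint32_t dst = static_cast<uint32_t>(__cvta_generic_to_shared(smem_dst));
            asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n" ::"r"(dst), "l"(gmem_src));
        }
    #else
        // SM75 path: synchronous load through the read-only cache, then a plain store.
        if (mask) {
            *smem_dst = __ldg(gmem_src);
        }
    #endif
    }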
22 changes: 21 additions & 1 deletion src/turbomind/kernels/gemm_s_f16/gemm_template.h
@@ -9,6 +9,23 @@
 
 namespace turbomind {
 
+__inline__ __device__ void
+mma_m16n8k8_row_col(Array<float, 4>& d, const Array<half, 4>& a, const Array<half, 2>& b, Array<float, 4>& c)
+{
+#if TURBOMIND_ARCH_SM75
+    uint32_t const* A = reinterpret_cast<uint32_t const*>(&a);
+    uint32_t const* B = reinterpret_cast<uint32_t const*>(&b);
+    float const*    C = reinterpret_cast<float const*>(&c);
+    float*          D = reinterpret_cast<float*>(&d);
+    asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, "
+        "{%4,%5}, {%6}, {%7,%8,%9,%10};\n"
+        : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
+        : "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
+#else
+    assert(TURBOMIND_ARCH_SM75);
+#endif
+}
+
 __inline__ __device__ void
 mma_m16n8k16_row_col(Array<float, 4>& d, const Array<half, 8>& a, const Array<half, 4>& b, Array<float, 4>& c)
 {
@@ -22,7 +39,10 @@ mma_m16n8k16_row_col(Array<float, 4>& d, const Array<half, 8>& a, const Array<half, 4>& b, Array<float, 4>& c)
         : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
         : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
 #else
-    assert(TURBOMIND_ARCH_SM80);
+    const Array<half, 4>* _a = (const Array<half, 4>*)&a;
+    const Array<half, 2>* _b = (const Array<half, 2>*)&b;
+    mma_m16n8k8_row_col(d, _a[0], _b[0], c);
+    mma_m16n8k8_row_col(d, _a[1], _b[1], d);
 #endif
 }
 
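SM75 lacks an f16 m16n8k16 MMA, which is why the fallback above splits K: the m16n8k16 A and B register fragments are laid out as two m16n8k8 fragments stacked along K, so reinterpreting a as two Array<half, 4> halves and b as two Array<half, 2> halves yields the operands of two smaller MMAs chained through the accumulator (D = A0*B0 + C, then D = A1*B1 + D). A plain C++ sanity check of that chaining on scalar matrices, illustrative only and not part of the commit:

    #include <cmath>
    #include <cstdio>

    constexpr int M = 16, N = 8, K = 16;

    // Reference MMA over columns k0..k1-1: D = A[:, k0:k1] * B[k0:k1, :] + C.
    // A, C, D row-major; B column-major, matching the .row.col qualifier.
    void mma_ref(float* D, const float* A, const float* B, const float* C, int k0, int k1)
    {
        for (int m = 0; m < M; ++m)
            for (int n = 0; n < N; ++n) {
                float acc = C[m * N + n];
                for (int k = k0; k < k1; ++k)
                    acc += A[m * K + k] * B[n * K + k];
                D[m * N + n] = acc;
            }
    }

    int main()
    {
        float A[M * K], B[N * K], C[M * N], D16[M * N], D8x2[M * N];
        for (int i = 0; i < M * K; ++i) A[i] = 0.01f * i;
        for (int i = 0; i < N * K; ++i) B[i] = 0.02f * i;
        for (int i = 0; i < M * N; ++i) C[i] = 1.0f;

        mma_ref(D16, A, B, C, 0, K);          // one k=16 MMA
        mma_ref(D8x2, A, B, C, 0, K / 2);     // first half:  D = A0*B0 + C
        mma_ref(D8x2, A, B, D8x2, K / 2, K);  // second half: D = A1*B1 + D

        float diff = 0.0f;
        for (int i = 0; i < M * N; ++i)
            diff = std::fmax(diff, std::fabs(D16[i] - D8x2[i]));
        std::printf("max |D16 - D8x2| = %g\n", diff);  // same addition order, so this prints 0
    }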
