Skip to content

Commit

Permalink
removing unnecessary specialization
Browse files Browse the repository at this point in the history
  • Loading branch information
Saeed Maleki committed Nov 12, 2023
1 parent 410f257 commit 282b28d
Showing 1 changed file with 2 additions and 68 deletions.
70 changes: 2 additions & 68 deletions include/mscclpp/sm_channel_device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,32 +16,22 @@ namespace Element {

/// Load an element from DRAM.
///
/// This is a warpper of ld.volatile.global.* PTX instruction. Address alignment is not this function's
/// responsibility.
///
/// @param v The value to be loaded.
/// @param p The address of the value to be loaded.
///
template <typename T>
__forceinline__ __device__ void load(T& v, const T* p) {
// We should only use the specialized functions.
v = *(volatile T*)p;
// __assert_fail("Unsupported type", __FILE__, __LINE__, __PRETTY_FUNCTION__);
v = *p;
}

/// Write an element on DRAM.
///
/// This is a wrapper of st.volatile.global.* PTX instruction. Address alignment is not this function's
/// responsibility.
///
/// @param p The address of the value to be written.
/// @param v The value to be written.
///
template <typename T>
__forceinline__ __device__ void store(T* p, const T& v) {
// We should only use the specialized functions.
*(volatile T*)p = v;
// __assert_fail("Unsupported type", __FILE__, __LINE__, __PRETTY_FUNCTION__);
*p = v;
}

/// Copy aligned elements from the source memory to the destination memory.
Expand All @@ -66,62 +56,6 @@ __forceinline__ __device__ void copy(T* dst, T* src, uint64_t numElems, uint32_t
}
}

template <>
__forceinline__ __device__ void load<long long>(long long& v, const long long* p) {
asm volatile("ld.volatile.global.u64 %0, [%1];" : "=l"(v) : "l"(p) : "memory");
}

template <>
__forceinline__ __device__ void store<long long>(long long* p, const long long& v) {
asm volatile("st.volatile.global.u64 [%0], %1;" : : "l"(p), "l"(v) : "memory");
}

template <>
__forceinline__ __device__ void load<int>(int& v, const int* p) {
asm volatile("ld.volatile.global.u32 %0, [%1];" : "=r"(v) : "l"(p) : "memory");
}

template <>
__forceinline__ __device__ void store<int>(int* p, const int& v) {
asm volatile("st.volatile.global.u32 [%0], %1;" : : "l"(p), "r"(v) : "memory");
}

template <>
__forceinline__ __device__ void load<longlong2>(longlong2& v, const longlong2* p) {
asm volatile("ld.volatile.global.v2.u64 {%0,%1}, [%2];" : "=l"(v.x), "=l"(v.y) : "l"(p) : "memory");
}

template <>
__forceinline__ __device__ void store<longlong2>(longlong2* p, const longlong2& v) {
asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" : : "l"(p), "l"(v.x), "l"(v.y) : "memory");
}

template <>
__forceinline__ __device__ void load<LLPacket>(LLPacket& v, const LLPacket* p) {
asm volatile("ld.volatile.global.v2.u64 {%0,%1}, [%2];" : "=l"(v.v[0]), "=l"(v.v[1]) : "l"(p) : "memory");
}

template <>
__forceinline__ __device__ void store<LLPacket>(LLPacket* p, const LLPacket& v) {
asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" : : "l"(p), "l"(v.v[0]), "l"(v.v[1]) : "memory");
}

template <>
__forceinline__ __device__ void load<int4>(int4& v, const int4* p) {
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
: "=r"(v.x), "=r"(v.y), "=r"(v.z), "=r"(v.w)
: "l"(p)
: "memory");
}

template <>
__forceinline__ __device__ void store<int4>(int4* p, const int4& v) {
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
:
: "l"(p), "r"(v.x), "r"(v.y), "r"(v.z), "r"(v.w)
: "memory");
}

} // namespace Element

#endif // __CUDACC__
Expand Down

0 comments on commit 282b28d

Please sign in to comment.