Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Intel gpu backend gemm pipeline #89

Merged

Conversation

Jiaxingla
Copy link
Collaborator

@Jiaxingla Jiaxingla commented Jul 3, 2024

  1. High performance gemm pipeline for pvc.
  2. Enable the prefetch by copy atom.
  3. Try the epilogue like ReLU and Softmax.

@Jiaxingla Jiaxingla requested review from taozha2 and jiyang1011 July 3, 2024 08:10
@Jiaxingla Jiaxingla marked this pull request as ready for review July 4, 2024 05:55
examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp Outdated Show resolved Hide resolved
examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp Outdated Show resolved Hide resolved
examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp Outdated Show resolved Hide resolved
examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp Outdated Show resolved Hide resolved
examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp Outdated Show resolved Hide resolved
examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp Outdated Show resolved Hide resolved
examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp Outdated Show resolved Hide resolved
examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp Outdated Show resolved Hide resolved
include/cutlass/epilogue/collective/default_epilogue.hpp Outdated Show resolved Hide resolved
@Jiaxingla Jiaxingla force-pushed the intel_gpu_backend_pipeline branch from 189753c to d82b5a4 Compare July 10, 2024 01:28
@@ -44,6 +44,18 @@ namespace cute
inline x { assert(false); }
#endif

enum LSC_LDCC {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp Outdated Show resolved Hide resolved
examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp Outdated Show resolved Hide resolved
//
// Methods
//

bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) {
bool verify(const ProblemShapeType &problem_size, ElementCompute alpha,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you use automatic code formatting? https://github.com/NVIDIA/cutlass/blob/main/media/docs/programming_guidelines.md#no-automatic-code-formatting

I hope we can use it because I think having to review formatting is a pain. But if so we need to find a tool which follows all of https://github.com/NVIDIA/cutlass/blob/main/media/docs/programming_guidelines.md and only apply to lines we change (e.g. git clang-format) so that we don't reformat any of the upstream code.

Here format is again wrong (west const, space alignment, line length). Please try to fix it everywhere and I'll review formatting after.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

--Did you use automatic code formatting?
Yes, we used a wrong code formatting.
I'll correct the clang-format and try my best to fix the format everywhere with cutlass guidelines.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you are using clang-format please add the config

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we aren't using Chang-format?

examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp Outdated Show resolved Hide resolved
template <class T>
CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width,
int height, int pitch, intel::coord_t coord) {
#if defined(SYCL_INTEL_TARGET)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This pattern is the same as on the unmodified line 170.

I think we can improve this by changing __SYCL_DEVICE_ONLY__ to

#ifdef __SYCL_DEVICE_ONLY__
#ifdef SYCL_INTEL_TARGET
#define SYCL_DEVICE_BUILTIN(x) SYCL_EXTERNAL extern "C" x
#else
#define SYCL_DEVICE_BUILTIN(x)  \
  inline x { CUTE_INVALID_CONTROL_PATH("Trying to use IGC built-in on non-Intel hardware"); }
#endif
#else
#define SYCL_DEVICE_BUILTIN(x)  \
  inline x { CUTE_INVALID_CONTROL_PATH("Trying to use device built-in on host."); }
#endif 

Up to you whether you want to do this as part of this PR.

include/cutlass/gemm/collective/intel_pvc_mma.hpp Outdated Show resolved Hide resolved
include/cutlass/relatively_equal.h Outdated Show resolved Hide resolved
include/cutlass/gemm/collective/intel_pvc_mma.hpp Outdated Show resolved Hide resolved
include/cutlass/gemm/collective/intel_pvc_mma.hpp Outdated Show resolved Hide resolved
include/cutlass/gemm/kernel/intel_pvc_gemm.hpp Outdated Show resolved Hide resolved
@Jiaxingla Jiaxingla force-pushed the intel_gpu_backend_pipeline branch 5 times, most recently from e7381bc to 26d8a88 Compare July 16, 2024 08:15
@Jiaxingla Jiaxingla force-pushed the intel_gpu_backend_pipeline branch from 576261d to 6bdda75 Compare July 17, 2024 06:31
@@ -362,14 +340,14 @@ int main(int argc, const char** argv)
using LayoutD = cutlass::layout::RowMajor;

using GmemTiledCopyA = XE_2D_U16x8x16x4x2_LD_N;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

adding some explanation for naming conventions of copy function

} else if constexpr (is_tuple<typename Tensor<SEngine,
SLayout>::engine_type::iterator::
value_type>::value) {
return copy_unpack(*this, src, dst);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed please remove this elseif condition and make the changes in the copy traits definition so the execution go into the first if check

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For this you need to make SrcLayout in the prefetch copy trait to match decltype(size(src))>::value

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Jiaxingla and others added 4 commits July 31, 2024 14:10
}
int prefetch_k = 0;

// Manually set the value to 1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this comment say Manually set the value to 3?

Copy link
Collaborator

@aacostadiaz aacostadiaz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for addressing the comments. LGTM!

Copy link
Collaborator

@mehdi-goli mehdi-goli left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@rolandschulz
Copy link
Collaborator

Some of my comments haven't been addressed. Are you planning to address them before merging or after?

@Jiaxingla
Copy link
Collaborator Author

Jiaxingla commented Aug 1, 2024

Some of my comments haven't been addressed. Are you planning to address them before merging or after?

Hi @rolandschulz , i do want to address or answer all your comments before merging. Because PR is completely different with beginning, if i have missed any of your comments, please let me know.
Apologize for any inconvenience caused.

@rolandschulz
Copy link
Collaborator

I clicked "resolve" on everything which I can see being addressed. Please check the remaining 3.

@Jiaxingla
Copy link
Collaborator Author

I clicked "resolve" on everything which I can see being addressed. Please check the remaining 3.

I resolved the format of enum class and answered the questions about clang-format.
The last about SYCL macro will fix in the next PR about Cute feature.

@@ -44,6 +44,17 @@ namespace cute
inline x { assert(false); }
#endif

enum CacheControl {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this be a scoped enum?

include/cute/arch/copy_xe.hpp Outdated Show resolved Hide resolved
int height, int pitch, intel::coord_t coord) {
#if defined(SYCL_INTEL_TARGET)
static_assert(sizeof(T) == 2, "Expected T to have size 2");
__builtin_IB_subgroup_block_read_prefetch_u16_m16k16v1(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why does this not match the non-prefetch in the same struct?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems to still not match (u16/u32).

@aacostadiaz
Copy link
Collaborator

CI is failing due to changes introduced in the latest nightly version of DPCPP. Rebasing on sycl_develop should resolve this.

};

/// @brief This function loads data from 2D memory surface.
/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers.
/// Loads 1x1 memory blocks, and each block size is 8x16x32bits
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we use different ordering of dimensions for the comment and the name of the struct?

@aacostadiaz
Copy link
Collaborator

I think all comments have been addressed. If there is something missing, please let us know and it will be added in the follow up PR.

@aacostadiaz aacostadiaz merged commit 0b5c911 into codeplaysoftware:sycl-develop Aug 2, 2024
4 checks passed
aacostadiaz pushed a commit to aacostadiaz/cutlass-fork that referenced this pull request Aug 6, 2024
Enable the prefetch by copy atom.

---------

Co-authored-by: Mehdi Goli <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants