This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Init arch Xe2 #298

Open · wants to merge 34 commits into base: xetla

Commits (34)
536e03e
save
sunjiweiswift May 17, 2024
5949084
save(some error with kslicing)
sunjiweiswift May 21, 2024
0669b12
fix kslicing bug
sunjiweiswift May 22, 2024
aafe774
save(g128 MTL 270Gflops bug on g32)
sunjiweiswift May 24, 2024
1b9a443
add Specialized for FPU
sunjiweiswift May 24, 2024
194ca35
support int scale col_major(with opt 10% perf when g = 32)
sunjiweiswift May 27, 2024
2bc4877
support int4x8 for int32 weight
sunjiweiswift May 27, 2024
8b9df8b
Update include/experimental/group/gemm/compute_policy.hpp
sunjiweiswift May 28, 2024
e8d3fbb
Update include/experimental/group/gemm/compute_policy.hpp
sunjiweiswift May 28, 2024
b0621df
save(perf bug with int4x8 load)
sunjiweiswift May 28, 2024
56be57a
save
sunjiweiswift May 29, 2024
2b37173
add first token UT
sunjiweiswift May 30, 2024
f973aa2
opt mma code
sunjiweiswift May 30, 2024
0f36c04
opt perf for int4x8
sunjiweiswift May 30, 2024
d9902d8
support load one fp16 data
sunjiweiswift May 31, 2024
30b8e95
support zero_pt
sunjiweiswift May 31, 2024
885995f
support ASYM and SYM
sunjiweiswift Jun 3, 2024
7e99e68
save
sunjiweiswift Jun 4, 2024
150f7d3
ut improve
sunjiweiswift Jun 6, 2024
ddbac97
support sg_n > 1
sunjiweiswift Jun 6, 2024
d2aff4b
add #pragma unroll
sunjiweiswift Jun 7, 2024
97c2481
support HF zero pt layout K x N, compress int4 along N dimensions
sunjiweiswift Jun 7, 2024
f19c86f
save
sunjiweiswift Jun 11, 2024
897f5d5
sg_m =4 for first token
sunjiweiswift Jun 14, 2024
e7f2716
Extract dequant func
sunjiweiswift Jun 14, 2024
0ebd890
update row_major for origin PVC/ARC template
sunjiweiswift Jun 17, 2024
b2dfad5
save(fix HPC 2D load)
sunjiweiswift Jun 17, 2024
8817f54
fix XEHPC 2D load
sunjiweiswift Jun 17, 2024
957c5a4
fix compile for all UT
sunjiweiswift Jun 17, 2024
5456fc0
sync ipex 20240618
DDEle Jun 19, 2024
9185409
opt PVC arch
sunjiweiswift Jun 19, 2024
93c8ad1
fix group_qkv
sunjiweiswift Jun 19, 2024
8f0abc4
fix group_qkv
sunjiweiswift Jun 20, 2024
dc7d812
init arch Xe2
airMeng Jun 20, 2024
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -66,7 +66,7 @@ else()
set(XETLA_KERNEL_FLAGS ${XETLA_KERNEL_FLAGS} -Xs "${XETLA_OFFLINE_OPTIONS}")
endif()

add_compile_options(-fsycl -fsycl-device-code-split=per_kernel)
add_compile_options(-fsycl -fsycl-device-code-split=per_kernel -ftemplate-backtrace-limit=0)
add_compile_options(-Wall -Wextra -Werror)

include(ProcessorCount)
28 changes: 14 additions & 14 deletions examples/05_batch_gemm/batch_gemm.hpp
@@ -276,20 +276,20 @@ class batch_gemm_t {
args.matB_base.base, args.matB_ld);
}
}
if (epilogue_t::msg_type_c != msg_type::unaligned_2d) {
if (epilogue_t::msg_type_c == msg_type::block_2d) {
implementable &=
kernel::block_2d<gpu_arch::XeHpc, dtype_c>::check_tensor(
(uint64_t)(args.matC_base.base),
args.matrix_n,
args.matrix_m * args.batch_size,
args.matC_ld);
} else {
implementable &=
kernel::general_1d<gpu_arch::XeHpc, dtype_c>::check_alignment(
args.matC_base.base, args.matC_ld);
}
}
// if (epilogue_t::msg_type_c != msg_type::unaligned_2d) {
// if (epilogue_t::msg_type_c == msg_type::block_2d) {
// implementable &=
// kernel::block_2d<gpu_arch::XeHpc, dtype_c>::check_tensor(
// (uint64_t)(args.matC_base.base),
// args.matrix_n,
// args.matrix_m * args.batch_size,
// args.matC_ld);
// } else {
// implementable &=
// kernel::general_1d<gpu_arch::XeHpc, dtype_c>::check_alignment(
// args.matC_base.base, args.matC_ld);
// }
// }

return implementable;
}
56 changes: 28 additions & 28 deletions examples/07_multi_layer_perceptron/multi_layer_perceptron.hpp
@@ -409,20 +409,20 @@ class multi_layer_perceptron_t {
args.matW_base.base, args.matW_ld);
}
}
if (epilogue_layer1_t::msg_type_c != msg_type::unaligned_2d) {
if (epilogue_layer1_t::msg_type_c == msg_type::block_2d) {
implementable &=
kernel::block_2d<gpu_arch::XeHpc, dtype_b>::check_tensor(
(uint64_t)(args.matB_base.base),
args.matrix_n_layer1,
args.matrix_m_layer1,
args.matB_ld);
} else {
implementable &=
kernel::general_1d<gpu_arch::XeHpc, dtype_b>::check_alignment(
args.matB_base.base, args.matB_ld);
}
}
// if (epilogue_layer1_t::msg_type_c != msg_type::unaligned_2d) {
// if (epilogue_layer1_t::msg_type_c == msg_type::block_2d) {
// implementable &=
// kernel::block_2d<gpu_arch::XeHpc, dtype_b>::check_tensor(
// (uint64_t)(args.matB_base.base),
// args.matrix_n_layer1,
// args.matrix_m_layer1,
// args.matB_ld);
// } else {
// implementable &=
// kernel::general_1d<gpu_arch::XeHpc, dtype_b>::check_alignment(
// args.matB_base.base, args.matB_ld);
// }
// }
if (gemm_layer2_t::msg_type_a != msg_type::unaligned_2d) {
if (gemm_layer2_t::msg_type_a == msg_type::block_2d) {
implementable &=
@@ -451,20 +451,20 @@ class multi_layer_perceptron_t {
args.matV_base.base, args.matV_ld);
}
}
if (epilogue_layer2_t::msg_type_c != msg_type::unaligned_2d) {
if (epilogue_layer2_t::msg_type_c == msg_type::block_2d) {
implementable &=
kernel::block_2d<gpu_arch::XeHpc, dtype_c>::check_tensor(
(uint64_t)(args.matC_base.base),
args.matrix_n_layer2,
args.matrix_m_layer2,
args.matC_ld);
} else {
implementable &=
kernel::general_1d<gpu_arch::XeHpc, dtype_c>::check_alignment(
args.matC_base.base, args.matC_ld);
}
}
// if (epilogue_layer2_t::msg_type_c != msg_type::unaligned_2d) {
// if (epilogue_layer2_t::msg_type_c == msg_type::block_2d) {
// implementable &=
// kernel::block_2d<gpu_arch::XeHpc, dtype_c>::check_tensor(
// (uint64_t)(args.matC_base.base),
// args.matrix_n_layer2,
// args.matrix_m_layer2,
// args.matC_ld);
// } else {
// implementable &=
// kernel::general_1d<gpu_arch::XeHpc, dtype_c>::check_alignment(
// args.matC_base.base, args.matC_ld);
// }
// }

return implementable;
}
11 changes: 7 additions & 4 deletions examples/08_scaled_dot_product_attention/softmax.hpp
@@ -60,18 +60,21 @@ struct xetla_softmax_fwd_t {
using softmax_tile_desc_t = subgroup::
tile_desc_t<SIMD, block_height, SIMD, block_height, reg_layout::tiled>;
using softmax_load_t = subgroup::tile_t<dtype_in, softmax_tile_desc_t>;
using mem_desc_in_t = mem_desc_t<dtype_in, mem_layout::row_major, mem_space_in>;
using softmax_load_payload_t = subgroup::mem_payload_t<
mem_desc_t<dtype_in, mem_layout::row_major, mem_space_in>,
mem_desc_in_t,
softmax_tile_desc_t,
subgroup::msg_type_v<softmax_tile_desc_t, mem_space_in>,
subgroup::msg_type_v<softmax_tile_desc_t, mem_desc_in_t>,
arch_tag>;

// this tile will store the softmax result to global memory
using softmax_store_t = subgroup::tile_t<dtype_out, softmax_tile_desc_t>;
using mem_desc_out_t =
mem_desc_t<dtype_out, mem_layout::row_major, mem_space_out>;
using softmax_store_payload_t = subgroup::mem_payload_t<
mem_desc_t<dtype_out, mem_layout::row_major, mem_space_out>,
mem_desc_out_t,
softmax_tile_desc_t,
subgroup::msg_type_v<softmax_tile_desc_t, mem_space_out>,
subgroup::msg_type_v<softmax_tile_desc_t, mem_desc_out_t>,
arch_tag>;

struct arguments_t {
2 changes: 1 addition & 1 deletion examples/09_gate_recurrent_unit/kernel_func.hpp
@@ -156,7 +156,7 @@ struct gru_layer {
using mat_hidden_payload_t = mem_payload_t<
mem_desc_a_t,
matC_tile_desc_t,
msg_type_v<matC_tile_desc_t, mem_loc_input>,
msg_type_v<matC_tile_desc_t, mem_desc_a_t>,
gpu_arch::XeHpc>;
using matC_payload_t = mem_payload_t<
mem_desc_c_t,
63 changes: 52 additions & 11 deletions include/common/core/arch_config.hpp
@@ -31,9 +31,8 @@ struct load_store_attr_t {
static constexpr bool has_hw_block_2d = false;
};

template <>
struct load_store_attr_t<msg_type::block_2d, gpu_arch::XeHpc> {
/// HW limitation checks https://gfxspecs.intel.com/Predator/Home/Index/55490
Comment on lines -34 to -36 (Contributor):

There are multiple places with the gfxspecs link. I think they are helpful for internal developers. If they violate any company policies, they should be removed all at once in a separate PR.

template <msg_type message_type, gpu_arch arg_tag>
struct xe_plus_load_store_attr_base_t {
static constexpr bool has_hw_block_2d = true;
static constexpr uint32_t max_load_height_in_elem = 32;
static constexpr uint32_t max_load_width_in_bytes = 64;
@@ -55,10 +54,9 @@ struct load_store_attr_t<msg_type::block_2d, gpu_arch::XeHpc> {

template <msg_type message_type, gpu_arch arg_tag>
struct client_load_store_attr_base_t {
/// HW limitation checks https://gfxspecs.intel.com/Predator/Home/Index/55490
static constexpr bool has_hw_block_2d = false;
static constexpr uint32_t max_load_height_in_elem = 32;
static constexpr uint32_t max_load_width_in_bytes = 64;
static constexpr uint32_t max_load_height_in_elem = 0;
static constexpr uint32_t max_load_width_in_bytes = 0;
static constexpr uint32_t max_trans_load_width_in_bytes = 32;
static constexpr uint32_t max_vnni_load_width_in_elems = 16;
static constexpr uint32_t min_vnni_load_height_in_bytes = 4;
@@ -87,21 +85,40 @@ struct load_store_attr_t<msg_type::block_2d, gpu_arch::XeLpg>
msg_type::block_2d,
gpu_arch::XeLpg> {};

template <>
struct load_store_attr_t<msg_type::block_2d, gpu_arch::XeHpc>
: public xe_plus_load_store_attr_base_t<
msg_type::block_2d,
gpu_arch::XeHpc> {};

template <>
struct load_store_attr_t<msg_type::block_2d, gpu_arch::Xe2>
: public xe_plus_load_store_attr_base_t<
msg_type::block_2d,
gpu_arch::Xe2> {};

template <gpu_arch arch_tag>
inline constexpr bool arch_has_2d_load_store =
load_store_attr_t<msg_type::block_2d, arch_tag>::has_hw_block_2d;

template <gpu_arch arch_tag>
struct load_store_attr_t<msg_type::block_1d, arch_tag> {
static constexpr uint32_t max_load_vec_len = 32;
static constexpr uint32_t max_store_vec_len = 32;
static constexpr uint32_t max_load_vec_len = 256;
static constexpr uint32_t max_store_vec_len = 256;
static constexpr uint32_t max_prefetch_vec_len = 32;
};

template <>
struct load_store_attr_t<msg_type::block_1d, gpu_arch::XeHpc> {
static constexpr uint32_t max_load_vec_len = 64;
static constexpr uint32_t max_store_vec_len = 64;
static constexpr uint32_t max_load_vec_len = 512;
static constexpr uint32_t max_store_vec_len = 512;
static constexpr uint32_t max_prefetch_vec_len = 64;
};

template <>
struct load_store_attr_t<msg_type::block_1d, gpu_arch::Xe2> {
static constexpr uint32_t max_load_vec_len = 512;
static constexpr uint32_t max_store_vec_len = 512;
static constexpr uint32_t max_prefetch_vec_len = 64;
};

@@ -129,6 +146,11 @@ struct dpas_attr_t<gpu_arch::XeHpg> : public dpas_attr_base_t {
static constexpr uint32_t n_fixed_limit = 8;
};

template <>
struct dpas_attr_t<gpu_arch::Xe2> : public dpas_attr_t<gpu_arch::XeHpc> {
static constexpr uint32_t systolic_depth = 4;
};

template <gpu_arch arch_tag>
inline constexpr bool arch_has_xmx = dpas_attr_t<arch_tag>::has_xmx;

@@ -162,6 +184,10 @@ template <>
struct register_bytes_t<gpu_arch::XeLpg> {
static constexpr uint32_t reg_in_bytes = 32;
};
template <>
struct register_bytes_t<gpu_arch::Xe2> {
static constexpr uint32_t reg_in_bytes = 64;
};

template <grf_mode grf_num_mode, gpu_arch arch_tag>
struct register_attr_t {
@@ -236,10 +262,25 @@ struct arch_attr_t<gpu_arch::XeLpg> {

using dpas_attr = dpas_attr_t<gpu_arch::XeLpg>;

static constexpr uint32_t max_wg_num = 64;
static constexpr uint32_t max_wg_num = 16;
static constexpr uint32_t local_mem_size = 64 * 1024;
};

template <>
struct arch_attr_t<gpu_arch::Xe2> {
template <msg_type message_type = msg_type::block_2d>
using load_store_attr = load_store_attr_t<message_type, gpu_arch::Xe2>;

template <grf_mode grf_num_mode = grf_mode::double_grf>
using register_attr = register_attr_t<grf_num_mode, gpu_arch::Xe2>;

using dpas_attr = dpas_attr_t<gpu_arch::Xe2>;

static constexpr uint32_t max_wg_num = 16;
static constexpr uint32_t local_mem_size = 128 * 1024;
};


/// @} xetla_core_arch_config

} // namespace gpu::xetla
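
A minimal compile-time sketch of how the Xe2 attributes added above could be queried. The include path is assumed; every name and value comes from this hunk, and the snippet is illustrative rather than part of the PR's diff.

// Sketch only: static checks against the Xe2 attributes introduced above.
#include <common/core/arch_config.hpp> // assumed include path

namespace gpu::xetla {
// Xe2 keeps HW 2D block load/store support via the Xe-plus base attributes.
static_assert(arch_has_2d_load_store<gpu_arch::Xe2>);
// Per-architecture limits added for Xe2 in this change.
static_assert(arch_attr_t<gpu_arch::Xe2>::max_wg_num == 16);
static_assert(arch_attr_t<gpu_arch::Xe2>::local_mem_size == 128 * 1024);
static_assert(register_bytes_t<gpu_arch::Xe2>::reg_in_bytes == 64);
static_assert(dpas_attr_t<gpu_arch::Xe2>::systolic_depth == 4);
static_assert(
    load_store_attr_t<msg_type::block_1d, gpu_arch::Xe2>::max_load_vec_len ==
    512);
} // namespace gpu::xetla
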
5 changes: 2 additions & 3 deletions include/common/core/base_consts.hpp
@@ -23,9 +23,8 @@

namespace gpu::xetla {

/// @addtogroup xetla_core_base_types
/// @addtogroup xetla_core_base_consts
/// @{

/// @} xetla_core_base_types
/// @} xetla_core_base_consts

} // namespace gpu::xetla
40 changes: 40 additions & 0 deletions include/common/core/base_types.hpp
@@ -55,6 +55,32 @@ using fp16 = sycl::half;
///
using tf32 = sycl::ext::intel::experimental::esimd::tfloat32;

/// @brief xetla 4-bit data packed as an 8-bit data type.
/// Two 4-bit values are packed into one byte.
struct int4x2 {
uint8_t data;

operator uint8_t() const {
return data;
}
int4x2(uint8_t val) {
data = val;
}
};

/// @brief xetla 4-bit data packed as a 32-bit data type.
/// Eight 4-bit values are packed into four bytes.
struct int4x8 {
uint32_t data;

operator uint32_t() const {
return data;
}
int4x8(uint32_t val) {
data = val;
}
};

/// @brief mx_fp4(E2M1) data packed as 8bits data type.
struct mx_fp4 {
uint8_t data;
@@ -89,6 +115,8 @@ template <typename T>
struct is_internal_type {
static constexpr bool value = std::is_same<remove_const_t<T>, bf16>::value ||
std::is_same<remove_const_t<T>, tf32>::value ||
std::is_same<remove_const_t<T>, int4x2>::value ||
std::is_same<remove_const_t<T>, int4x8>::value ||
std::is_same<remove_const_t<T>, mx_fp4>::value;
};
template <typename T>
@@ -137,6 +165,18 @@ struct native_type<mx_fp4> {
using type = uint8_t;
};

/// @brief Set uint8_t as the native data type of int4x2.
template <>
struct native_type<int4x2> {
using type = uint8_t;
};

/// @brief Set uint32_t as the native data type of int4x8.
template <>
struct native_type<int4x8> {
using type = uint32_t;
};

/// @brief Return the native data type of T
template <typename T>
using native_type_t = typename native_type<T>::type;
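
A minimal usage sketch for the packed-int4 carrier types defined above. The helpers pack_int4x2, unpack_lo and unpack_hi, the low-nibble-first ordering, and the include path are assumptions for illustration; only int4x2, int4x8 and native_type_t come from this file.

// Sketch only: packing two 4-bit values into an int4x2 and reading them back.
#include <common/core/base_types.hpp> // assumed include path
#include <cstdint>
#include <type_traits>

using gpu::xetla::int4x2;
using gpu::xetla::int4x8;
using gpu::xetla::native_type_t;

// Hypothetical helper: pack two 4-bit values into one byte, low nibble first.
inline int4x2 pack_int4x2(uint8_t lo, uint8_t hi) {
  return int4x2(static_cast<uint8_t>((lo & 0xF) | ((hi & 0xF) << 4)));
}
// Hypothetical helpers: read the nibbles back via the implicit uint8_t conversion.
inline uint8_t unpack_lo(int4x2 v) {
  return static_cast<uint8_t>(v) & 0xF;
}
inline uint8_t unpack_hi(int4x2 v) {
  return (static_cast<uint8_t>(v) >> 4) & 0xF;
}

// Storage types match the native_type specializations above.
static_assert(std::is_same_v<native_type_t<int4x2>, uint8_t>);
static_assert(std::is_same_v<native_type_t<int4x8>, uint32_t>);
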
11 changes: 10 additions & 1 deletion include/common/core/common_types.hpp
@@ -21,9 +21,18 @@
#include <cstdint>

namespace gpu::xetla {
enum class gpu_arch : uint8_t { XeLpg = 0, XeHpg = 1, XeHpc = 2 };
enum class gpu_arch : uint8_t { XeLpg = 0, XeHpg = 1, XeHpc = 2, Xe2 = 3 };

enum class grf_mode : uint8_t { normal = 0, double_grf = 1 };

enum class mem_layout : uint8_t { row_major = 0, col_major = 1 };

enum class quant_mode : uint8_t { S4_ASYM = 0, S4_FULLRANGE_NO_ZP = 1 };

struct quant_info {
quant_mode quant_mode;
uint32_t dequant_s;
mem_layout weight_mem_layout;
};

} // namespace gpu::xetla
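
A minimal sketch of filling the quant_info struct added above. The include path and the concrete values (group size 32, column-major weights) are illustrative assumptions; the type, enum and field names come from this hunk.

// Sketch only: describing an asymmetric int4 weight with per-group dequant scales.
#include <common/core/common_types.hpp> // assumed include path

int main() {
  using namespace gpu::xetla;
  quant_info qinfo{};
  qinfo.quant_mode = quant_mode::S4_ASYM; // int4 with zero point
  qinfo.dequant_s = 32; // dequantization group size (illustrative)
  qinfo.weight_mem_layout = mem_layout::col_major; // weight stored column-major
  // Symmetric / full-range weights would use quant_mode::S4_FULLRANGE_NO_ZP instead.
  return 0;
}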