Skip to content

Commit

Permalink
Merge branch 'dev/3.0' into feature/xpu_backend
Browse files Browse the repository at this point in the history
  • Loading branch information
xhuohai authored Sep 30, 2024
2 parents c4e916b + d058651 commit 9939e98
Show file tree
Hide file tree
Showing 10 changed files with 594 additions and 171 deletions.
70 changes: 70 additions & 0 deletions src/Native/include/nncase/ntt/arch/x86_64/ukernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,74 @@ class u_pack<M, N, MStrides, true, float, vector<float, 8>> {
template <reduce_op Op, class T> struct u_reduce_policy<Op, T, true> {
static constexpr size_t unroll = 8;
};

template <>
struct u_matmul_policy<mamtul_pack_kind::no_pack, float, float, float, true> {
static constexpr size_t m0_tile = 1;
static constexpr size_t n0_tile = 1;
static constexpr size_t m0_subtile = 0;
};

// Pack M
template <>
struct u_matmul_policy<mamtul_pack_kind::pack_m, vector<float, 8>, float,
vector<float, 8>, true> {
static constexpr size_t m0_tile = 2;
static constexpr size_t n0_tile = 4;
static constexpr size_t m0_subtile = 0;
};

// Pack K
template <>
struct u_matmul_policy<mamtul_pack_kind::pack_k, vector<float, 8>,
vector<float, 8>, float, true> {
static constexpr size_t m0_tile = 2;
static constexpr size_t n0_tile = 2;
static constexpr size_t m0_subtile = 0;
};

// Pack N
template <>
struct u_matmul_policy<mamtul_pack_kind::pack_n, float, vector<float, 8>,
vector<float, 8>, true> {
static constexpr size_t m0_tile = 4;
static constexpr size_t n0_tile = 2;
static constexpr size_t m0_subtile = 0;
};

// Pack MN
template <>
struct u_matmul_policy<mamtul_pack_kind::pack_mn, vector<float, 8>,
vector<float, 8>, vector<float, 8, 8>, true> {
static constexpr size_t m0_tile = 1;
static constexpr size_t n0_tile = 2;
static constexpr size_t m0_subtile = 4;
};

// Pack MK
template <>
struct u_matmul_policy<mamtul_pack_kind::pack_mk, vector<float, 8, 8>,
vector<float, 8>, vector<float, 8>, true> {
static constexpr size_t m0_tile = 1;
static constexpr size_t n0_tile = 1;
static constexpr size_t m0_subtile = 0;
};

// Pack KN
template <>
struct u_matmul_policy<mamtul_pack_kind::pack_kn, vector<float, 8>,
vector<float, 8, 8>, vector<float, 8>, true> {
static constexpr size_t m0_tile = 4;
static constexpr size_t n0_tile = 2;
static constexpr size_t m0_subtile = 0;
};

// Pack MKN
template <>
struct u_matmul_policy<mamtul_pack_kind::pack_mkn, vector<float, 8, 8>,
vector<float, 8, 8>, vector<float, 8, 8>, true> {
static constexpr size_t m0_tile = 1;
static constexpr size_t n0_tile = 2;
static constexpr size_t m0_subtile = 4;
};
} // namespace nncase::ntt::ukernels
4 changes: 4 additions & 0 deletions src/Native/include/nncase/ntt/detail/shape_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ template <class Shape> class shape_storage {

template <size_t... Dims> class shape_storage<fixed_shape<Dims...>> {
public:
constexpr shape_storage(fixed_shape<Dims...> = {}) noexcept {};

static constexpr size_t rank() noexcept { return sizeof...(Dims); }
static constexpr auto shape() noexcept { return fixed_shape<Dims...>{}; }
};
Expand All @@ -47,6 +49,8 @@ template <class Strides> class strides_storage {

template <size_t... Dims> class strides_storage<fixed_strides<Dims...>> {
public:
constexpr strides_storage(fixed_strides<Dims...> = {}) noexcept {};

static constexpr auto strides() noexcept {
return fixed_strides<Dims...>{};
}
Expand Down
Loading

0 comments on commit 9939e98

Please sign in to comment.