[c10/metal] Add a vectype variant for short/int/long (pytorch#145430)

Some of the kernels (exp_complex/atan_complex) need the specialization.

Pull Request resolved: pytorch#145430
Approved by: https://github.com/malfet, https://github.com/jansel
dcci authored and pytorchmergebot committed Jan 23, 2025
1 parent c581981 commit f56c638
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions c10/metal/utils.h
@@ -29,6 +29,28 @@ struct vectypes<bfloat> {
  using type2 = bfloat2;
};
#endif

template <>
struct vectypes<short> {
  using type4 = short4;
  using type3 = short3;
  using type2 = short2;
};

template <>
struct vectypes<int> {
  using type4 = int4;
  using type3 = int3;
  using type2 = int2;
};

template <>
struct vectypes<long> {
  using type4 = long4;
  using type3 = long3;
  using type2 = long2;
};

} // namespace detail

template <typename T>
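
The commit message says that kernels such as exp_complex and atan_complex need these specializations. As a rough illustration of why, the sketch below mirrors the trait pattern in plain host-side C++. The `short2`/`int2`/`long2` structs are hypothetical stand-ins for Metal's built-in vector types, and the `vec2type_t` alias and `complex_mul` helper are assumptions made for this example, not code taken from the commit.

```cpp
#include <type_traits>

// Hypothetical host-side stand-ins for Metal's built-in vector types.
struct short2 { short x, y; };
struct int2 { int x, y; };
struct long2 { long x, y; };

namespace detail {
// Primary template is empty: using an unsupported scalar type is a
// compile-time error instead of a silent fallback.
template <typename T>
struct vectypes {};

template <>
struct vectypes<short> { using type2 = short2; };

template <>
struct vectypes<int> { using type2 = int2; };

template <>
struct vectypes<long> { using type2 = long2; };
} // namespace detail

// Convenience alias (name assumed for illustration).
template <typename T>
using vec2type_t = typename detail::vectypes<T>::type2;

// Illustrative generic kernel helper: multiplies two complex numbers
// stored as (real, imaginary) pairs in a two-element vector.
template <typename T>
vec2type_t<T> complex_mul(vec2type_t<T> a, vec2type_t<T> b) {
  return {static_cast<T>(a.x * b.x - a.y * b.y),
          static_cast<T>(a.x * b.y + a.y * b.x)};
}

// The trait maps each scalar type to a vector of the same element type.
static_assert(std::is_same<vec2type_t<short>, short2>::value, "short -> short2");
static_assert(std::is_same<vec2type_t<long>, long2>::value, "long -> long2");
```

A templated helper like this instantiates only for scalar types that have a `vectypes` specialization, which is presumably why integer variants had to be added before complex kernels could be instantiated for integral inputs.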