diff --git a/python/src/llvm.cc b/python/src/llvm.cc
index 21ba2bd5aba4..e4be9846bcc4 100644
--- a/python/src/llvm.cc
+++ b/python/src/llvm.cc
@@ -580,6 +580,24 @@ void init_triton_llvm(py::module &&m) {
       if (f.second)
         res.insert(f.first().str());
     }
+
+    // Likely something went wrong with the LLVM feature detection.
+    if (!res.size()) {
+      std::string triple = llvm::sys::getProcessTriple();
+      // e.g. arm64-apple-darwin24.1.0
+      //      ^^^^^
+      std::size_t pos = triple.find('-');
+      if (pos == std::string::npos) {
+        return res;
+      }
+
+      std::string arch = triple.substr(0, pos);
+      if (arch == "aarch64" || arch == "arm64") {
+        // Safe because NEON is a mandatory feature for aarch64.
+        res.insert("neon"); // For math tests
+      }
+    }
+
     return res;
   });
 }
diff --git a/python/test/unit/cpu/test_math.py b/python/test/unit/cpu/test_math.py
index 958913e7f9f1..1fd443db967a 100644
--- a/python/test/unit/cpu/test_math.py
+++ b/python/test/unit/cpu/test_math.py
@@ -5,10 +5,23 @@
 
 import triton
 import triton.language as tl
+from triton._C.libtriton import llvm
 from triton.language.extra import libdevice
 from itertools import chain, product
 
 
+def get_native_vector_size_in_bits():
+    """
+    Returns the native vector size of the CPU in bits.
+    Assuming x86 always uses "auto dispatch" with 512-bit vectors for Sleef.
+    """
+    cpu_features = llvm.get_cpu_features()
+    # TODO: add support for Arm SVE with vector-length-agnostic (VLA) code
+    if "neon" in cpu_features:
+        return 128
+    return 512
+
+
 def is_interpreter():
     return os.environ.get('TRITON_INTERPRET', '0') == '1'
 
@@ -34,9 +47,13 @@ def check_num_vec_calls(meta, vec_lib, dtype_str, size, is_always_extern=False):
     # FP16 and BF16 are cast to FP32 for math ops
     elem_size = 8 if dtype_str == "float64" else 4
     data_size = size * elem_size
-    if data_size > 64:
-        num_vec_calls = data_size // 64
-    elif data_size >= 16:
+
+    vec_size = get_native_vector_size_in_bits() / 8  # bytes
+    # A 128-bit vector is the smallest supported by Sleef for both x86 and Arm
+    smallest_vec_size = 128 / 8  # bytes
+    if data_size > vec_size:
+        num_vec_calls = data_size // vec_size
+    elif data_size >= smallest_vec_size:
         num_vec_calls = 1
     else:
         num_vec_calls = 1 if is_always_extern else 0
diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py
index 5aabcc051b91..c4b5e6ecd918 100644
--- a/third_party/cpu/backend/compiler.py
+++ b/third_party/cpu/backend/compiler.py
@@ -162,7 +162,8 @@ def make_tttcir(self, mod, metadata, opt):
         pm.enable_debug()
         cpu.passes.ttcpuir.add_optimize_masks(pm)
         passes.common.add_canonicalizer(pm)
-        convert_bf16_dot_product = self.cpu_arch == "aarch64" and 'fp-armv8' in self.cpu_features and 'neon' in self.cpu_features
+        convert_bf16_dot_product = ((self.cpu_arch == "aarch64" or self.cpu_arch == "armv8")
+                                    and 'fp-armv8' in self.cpu_features and 'neon' in self.cpu_features)
         if convert_bf16_dot_product:
             use_horizontal_sum = os.getenv("TRITON_CPU_DOT_PROD_HORIZ_SUM", "1") == "1"
             cpu.passes.ttcpuir.add_convert_dot_product(pm, use_horizontal_sum)
@@ -215,7 +216,7 @@ def make_llir(self, src, metadata, options):
             VecLib.libmvec: {"avx512f"},
         }
         if (vec_lib := options.get_vec_lib()) and vec_lib_requirements[vec_lib] & self.cpu_features:
-            cpu.passes.ttcpuir.add_math_to_vec_lib(pm, vec_lib)
+            cpu.passes.ttcpuir.add_math_to_vec_lib(pm, vec_lib, self.cpu_features)
 
         passes.convert.add_math_to_llvmir(pm)
         cpu.passes.ttcpuir.add_math_to_libm(pm)
diff --git a/third_party/cpu/include/TritonCPUToLLVM/Passes.h b/third_party/cpu/include/TritonCPUToLLVM/Passes.h
index 6e9892d00206..cc29821c580c 100644
--- a/third_party/cpu/include/TritonCPUToLLVM/Passes.h
+++ b/third_party/cpu/include/TritonCPUToLLVM/Passes.h
@@ -32,7 +32,8 @@ std::unique_ptr<OperationPass<triton::FuncOp>> createLowerMultiReductionPass();
 std::unique_ptr<OperationPass<ModuleOp>> createAtomicOpsToLLVMPass();
 std::unique_ptr<OperationPass<ModuleOp>> createDebugOpsToLLVMPass();
 std::unique_ptr<OperationPass<ModuleOp>>
-createMathToVecLibPass(VecLib lib = VecLib::Sleef);
+createMathToVecLibPass(VecLib lib = VecLib::Sleef,
+                       std::set<std::string> cpu_features = {});
 
 #define GEN_PASS_REGISTRATION
 #include "cpu/include/TritonCPUToLLVM/Passes.h.inc"
diff --git a/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp b/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp
index b68d5a7473d0..2b1877c1c17b 100644
--- a/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp
+++ b/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp
@@ -56,14 +56,17 @@ template <typename OpT> struct VecOpToFp32 : public OpRewritePattern<OpT> {
 };
 
 // Decompose vector operation to single-dimensional vector operations
-// with a native AVX512 vector size.
+// with the native vector width: AVX512 for x86 or NEON for ARM.
 template <typename OpT>
 struct DecomposeToNativeVecs : public OpRewritePattern<OpT> {
 public:
   using OpRewritePattern<OpT>::OpRewritePattern;
+  // CPU SIMD vector size in bits
+  size_t vec_bits;
 
-  DecomposeToNativeVecs(MLIRContext *context)
-      : OpRewritePattern<OpT>(context) {}
+  DecomposeToNativeVecs(MLIRContext *context,
+                        size_t native_vec_size_in_bits = 512)
+      : OpRewritePattern<OpT>(context), vec_bits(native_vec_size_in_bits) {}
 
   LogicalResult matchAndRewrite(OpT op, PatternRewriter &rewriter) const {
     Location loc = op.getLoc();
@@ -83,7 +86,7 @@ struct DecomposeToNativeVecs : public OpRewritePattern<OpT> {
     // vector size.
     auto shape = vecTy.getShape();
     SmallVector<int64_t> newShape(1, 1);
-    int64_t elemsPerVec = 512 / elemTy.getIntOrFloatBitWidth();
+    int64_t elemsPerVec = vec_bits / elemTy.getIntOrFloatBitWidth();
     for (int64_t i = shape.size() - 1; i >= 0; --i) {
       int64_t size = shape[i];
       if (newShape.size() > 1) {
@@ -330,9 +333,11 @@ struct ExternElementwiseOpConversion
 
 template <typename OpTy>
 void populatePatternsForOp(RewritePatternSet &patterns,
-                           GetVecFnNameFn getVecFnName) {
+                           GetVecFnNameFn getVecFnName,
+                           size_t vec_size_in_bits = 512) {
   patterns.add<VecOpToFp32<OpTy>>(patterns.getContext());
-  patterns.add<DecomposeToNativeVecs<OpTy>>(patterns.getContext());
+  patterns.add<DecomposeToNativeVecs<OpTy>>(patterns.getContext(),
+                                            vec_size_in_bits);
   patterns.add<VecOpToVecLibConversion<OpTy>>(patterns.getContext(),
                                               getVecFnName);
 }
@@ -340,8 +345,27 @@ void populatePatternsForOp(RewritePatternSet &patterns,
 struct MathToVecLibPass
     : public mlir::triton::cpu::impl::MathToVecLibBase<MathToVecLibPass> {
   MathToVecLibPass() = default;
+  size_t vec_size_in_bits;
 
-  explicit MathToVecLibPass(VecLib lib) { this->lib = lib; }
+  explicit MathToVecLibPass(VecLib lib, std::set<std::string> cpu_features) {
+    this->lib = lib;
+    update_vec_size(cpu_features);
+  }
+
+  void update_vec_size(std::set<std::string> &cpu_features) {
+    // TODO:
+    //  Refactor this as an independent function.
+    //  Improve this to support other x86 SIMD ISAs as well as Arm SVE
+    //  with vector-length-agnostic (VLA) code.
+    vec_size_in_bits = 512;
+    for (auto feature : cpu_features) {
+      // Arm NEON is a fixed 128-bit SIMD ISA.
+      if (feature == "neon") {
+        vec_size_in_bits = 128;
+        break;
+      }
+    }
+  }
 
   void runOnOperation() override {
     Operation *op = getOperation();
@@ -356,20 +380,20 @@ struct MathToVecLibPass
     }
     case VecLib::Sleef: {
       populateCommonPatterns<SleefNameGenerator>(patterns);
-      populatePatternsForOp<math::ExpM1Op>(patterns,
-                                           SleefNameGenerator("expm1"));
+      populatePatternsForOp<math::ExpM1Op>(
+          patterns, SleefNameGenerator("expm1"), vec_size_in_bits);
       populatePatternsForOp<math::FloorOp>(
-          patterns, SleefNameGenerator("floor", /*ulp=*/0));
+          patterns, SleefNameGenerator("floor", /*ulp=*/0), vec_size_in_bits);
       populatePatternsForOp<math::SqrtOp>(
-          patterns, SleefNameGenerator("sqrt", /*ulp=*/5));
+          patterns, SleefNameGenerator("sqrt", /*ulp=*/5), vec_size_in_bits);
       populatePatternsForOp<math::TruncOp>(
-          patterns, SleefNameGenerator("trunc", /*ulp=*/0));
+          patterns, SleefNameGenerator("trunc", /*ulp=*/0), vec_size_in_bits);
       break;
     }
     }
 
     patterns.add<DecomposeToNativeVecs<ExternElementwiseOp>>(
-        patterns.getContext());
+        patterns.getContext(), vec_size_in_bits);
     patterns.add<PadSmallVecsForSleef>(patterns.getContext());
     patterns.add<ExternElementwiseOpConversion>(patterns.getContext());
 
@@ -379,26 +403,46 @@ struct MathToVecLibPass
 
   template <typename VecFnNameGenerator>
   void populateCommonPatterns(RewritePatternSet &patterns) const {
-    populatePatternsForOp<math::AcosOp>(patterns, VecFnNameGenerator("acos"));
-    populatePatternsForOp<math::AcoshOp>(patterns, VecFnNameGenerator("acosh"));
-    populatePatternsForOp<math::AsinOp>(patterns, VecFnNameGenerator("asin"));
-    populatePatternsForOp<math::AsinhOp>(patterns, VecFnNameGenerator("asinh"));
-    populatePatternsForOp<math::AtanOp>(patterns, VecFnNameGenerator("atan"));
-    populatePatternsForOp<math::AtanhOp>(patterns, VecFnNameGenerator("atanh"));
-    populatePatternsForOp<math::CbrtOp>(patterns, VecFnNameGenerator("cbrt"));
-    populatePatternsForOp<math::CosOp>(patterns, VecFnNameGenerator("cos"));
-    populatePatternsForOp<math::CoshOp>(patterns, VecFnNameGenerator("cosh"));
-    populatePatternsForOp<math::ErfOp>(patterns, VecFnNameGenerator("erf"));
-    populatePatternsForOp<math::ExpOp>(patterns, VecFnNameGenerator("exp"));
-    populatePatternsForOp<math::Exp2Op>(patterns, VecFnNameGenerator("exp2"));
-    populatePatternsForOp<math::LogOp>(patterns, VecFnNameGenerator("log"));
-    populatePatternsForOp<math::Log2Op>(patterns, VecFnNameGenerator("log2"));
-    populatePatternsForOp<math::Log10Op>(patterns, VecFnNameGenerator("log10"));
-    populatePatternsForOp<math::Log1pOp>(patterns, VecFnNameGenerator("log1p"));
-    populatePatternsForOp<math::SinOp>(patterns, VecFnNameGenerator("sin"));
-    populatePatternsForOp<math::SinhOp>(patterns, VecFnNameGenerator("sinh"));
-    populatePatternsForOp<math::TanOp>(patterns, VecFnNameGenerator("tan"));
-    populatePatternsForOp<math::TanhOp>(patterns, VecFnNameGenerator("tanh"));
+    populatePatternsForOp<math::AcosOp>(patterns, VecFnNameGenerator("acos"),
+                                        vec_size_in_bits);
+    populatePatternsForOp<math::AcoshOp>(patterns, VecFnNameGenerator("acosh"),
+                                         vec_size_in_bits);
+    populatePatternsForOp<math::AsinOp>(patterns, VecFnNameGenerator("asin"),
+                                        vec_size_in_bits);
+    populatePatternsForOp<math::AsinhOp>(patterns, VecFnNameGenerator("asinh"),
+                                         vec_size_in_bits);
+    populatePatternsForOp<math::AtanOp>(patterns, VecFnNameGenerator("atan"),
+                                        vec_size_in_bits);
+    populatePatternsForOp<math::AtanhOp>(patterns, VecFnNameGenerator("atanh"),
+                                         vec_size_in_bits);
+    populatePatternsForOp<math::CbrtOp>(patterns, VecFnNameGenerator("cbrt"),
+                                        vec_size_in_bits);
+    populatePatternsForOp<math::CosOp>(patterns, VecFnNameGenerator("cos"),
+                                       vec_size_in_bits);
+    populatePatternsForOp<math::CoshOp>(patterns, VecFnNameGenerator("cosh"),
+                                        vec_size_in_bits);
+    populatePatternsForOp<math::ErfOp>(patterns, VecFnNameGenerator("erf"),
+                                       vec_size_in_bits);
+    populatePatternsForOp<math::ExpOp>(patterns, VecFnNameGenerator("exp"),
+                                       vec_size_in_bits);
+    populatePatternsForOp<math::Exp2Op>(patterns, VecFnNameGenerator("exp2"),
+                                        vec_size_in_bits);
+    populatePatternsForOp<math::LogOp>(patterns, VecFnNameGenerator("log"),
+                                       vec_size_in_bits);
+    populatePatternsForOp<math::Log2Op>(patterns, VecFnNameGenerator("log2"),
+                                        vec_size_in_bits);
+    populatePatternsForOp<math::Log10Op>(patterns, VecFnNameGenerator("log10"),
+                                         vec_size_in_bits);
+    populatePatternsForOp<math::Log1pOp>(patterns, VecFnNameGenerator("log1p"),
+                                         vec_size_in_bits);
+    populatePatternsForOp<math::SinOp>(patterns, VecFnNameGenerator("sin"),
+                                       vec_size_in_bits);
+    populatePatternsForOp<math::SinhOp>(patterns, VecFnNameGenerator("sinh"),
+                                        vec_size_in_bits);
+    populatePatternsForOp<math::TanOp>(patterns, VecFnNameGenerator("tan"),
+                                       vec_size_in_bits);
+    populatePatternsForOp<math::TanhOp>(patterns, VecFnNameGenerator("tanh"),
+                                        vec_size_in_bits);
   }
 };
 
@@ -408,8 +452,9 @@ namespace mlir {
 namespace triton {
 namespace cpu {
 
-std::unique_ptr<OperationPass<ModuleOp>> createMathToVecLibPass(VecLib lib) {
-  return std::make_unique<MathToVecLibPass>(lib);
+std::unique_ptr<OperationPass<ModuleOp>>
+createMathToVecLibPass(VecLib lib, std::set<std::string> cpu_features) {
+  return std::make_unique<MathToVecLibPass>(lib, cpu_features);
 }
 
 } // namespace cpu
diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc
index 3c3555d3c983..a412190bbcf8 100644
--- a/third_party/cpu/triton_cpu.cc
+++ b/third_party/cpu/triton_cpu.cc
@@ -145,8 +145,9 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) {
   m.def("add_memref_to_llvmir", [](mlir::PassManager &pm) {
     pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass());
   });
-  m.def("add_math_to_vec_lib", [](mlir::PassManager &pm, cpu::VecLib lib) {
-    pm.addPass(mlir::triton::cpu::createMathToVecLibPass(lib));
+  m.def("add_math_to_vec_lib", [](mlir::PassManager &pm, cpu::VecLib lib,
+                                  std::set<std::string> cpu_features) {
+    pm.addPass(mlir::triton::cpu::createMathToVecLibPass(lib, cpu_features));
   });
   m.def("add_math_to_libm", [](mlir::PassManager &pm) {
     pm.addPass(mlir::createConvertMathToLibmPass());