From 9b8cdf2b9694a93791906b9b22079a386d53510b Mon Sep 17 00:00:00 2001 From: Adrian Lundell Date: Tue, 30 Sep 2025 15:51:23 +0200 Subject: [PATCH 1/3] Cortex_m backend: Simplify add + linear fusion passes Reuses the FoldAndAnnotateQParamsPass from the Arm backend to greatly simplify the logic for fusing the ops. Additionally updates the linear kernel to be numerically correct and computes the kernel_sum aot in the quantized_linear_fusion pass. Note that since this replaces the bias node it typically causes no extra memory usage. Updates the Linear tests to mirror this, including removing the various matmul tests. Since the linear is handled as a separate op rather than a particular type of matmul these tests are not related anymore. Removes unnecessary stub definitions in operators.py, operators.yaml and op_quantized_linear.cpp Leaving a few TODO:s since the patch is large already. Signed-off-by: Adrian Lundell Change-Id: I194228ee3ae4b64a92f3f818afb2e045cc3acf91 --- .../ops/cmsis_scratch_buffer_context.h | 187 ----- backends/cortex_m/ops/cortex_m_ops_common.h | 1 + backends/cortex_m/ops/op_quantized_linear.cpp | 195 ++--- backends/cortex_m/ops/operators.py | 254 ++----- backends/cortex_m/ops/operators.yaml | 10 +- .../cortex_m/passes/cortex_m_pass_manager.py | 8 +- .../passes/quantized_linear_fusion_pass.py | 703 +++--------------- .../passes/quantized_op_fusion_pass.py | 284 ++----- backends/cortex_m/test/ops/test_linear.py | 141 ++-- 9 files changed, 353 insertions(+), 1430 deletions(-) delete mode 100644 backends/cortex_m/ops/cmsis_scratch_buffer_context.h diff --git a/backends/cortex_m/ops/cmsis_scratch_buffer_context.h b/backends/cortex_m/ops/cmsis_scratch_buffer_context.h deleted file mode 100644 index 4b9fdaebdf7..00000000000 --- a/backends/cortex_m/ops/cmsis_scratch_buffer_context.h +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ -#pragma once - -#include "cortex_m_ops_common.h" -extern "C" { -#include "arm_nnfunctions.h" -} - -namespace cortex_m { -namespace native { - -// During AOT phase, quantized_linear_fusion_pass allocates total buffer -// and passes in as 'Tensor'. (Total buffer = 8-byte header + x bytes) -// ┌─────────────────┬─────────────────────────────────────┐ -// │ KernelSum Header│ CMSIS Workspace │ -// │ (8 bytes) │ (x bytes) │ -// └─────────────────┴─────────────────────────────────────┘ -// │ │ -// │ └─> Passed to CMSIS API -// │ -// └─> State for kernel sum - -// C++ Runtime: -// ┌─────────────────┬─────────────────────────────────────┐ -// │ KernelSum Header│ CMSIS Workspace │ -// │ (8 bytes) │ (x bytes) │ -// └─────────────────┴─────────────────────────────────────┘ -// ^ ^ -// │ │ -// scratch_ptr cmsis_workspace_ptr -// │ │ -// ▼ ▼ -// arm_vector_sum_s8() writes kernel sums (with bias if avail): -// [sum₀+bias₀][sum₁+bias₁][sum₂+bias₂]...[sum_{n-1}+bias_{n-1}] -// (n * 4-byte int32_t values = x bytes) -// -// - n = out_features (number of output features) -// - x = n * 4 bytes (total CMSIS buffer size) -// - Total buffer = 8 + x bytes - -class CMSISScratchBufferContext final { - public: - CMSISScratchBufferContext( - Tensor& scratch_buffer, - const Tensor& weights, - const Tensor& weight_zero_point, - const torch::executor::optional& bias) - : scratch_ptr_(scratch_buffer.mutable_data_ptr()), - total_size_(scratch_buffer.size(0)), - base_ptr_(reinterpret_cast(scratch_ptr_)), - in_features_(weights.size(1)), - out_features_(weights.size(0)), - is_per_channel_(weight_zero_point.numel() > 1), - weight_data_offset_(calculate_offset(weights.const_data_ptr())), - weight_zp_data_offset_( - calculate_offset(weight_zero_point.const_data_ptr())), - bias_data_offset_( - bias.has_value() - ? calculate_offset(bias.value().const_data_ptr()) - : 0), - header_(reinterpret_cast(scratch_ptr_)), - cmsis_workspace_ptr_(scratch_ptr_ + KERNEL_SUM_HEADER_SIZE) { - cmsis_nn_dims filter_dims = {in_features_, 1, 1, out_features_}; - validate_size(filter_dims); - } - - cmsis_nn_context get_cmsis_ctx() const { - cmsis_nn_context ctx; - ET_CHECK_MSG( - reinterpret_cast(cmsis_workspace_ptr_) % 4 == 0, - "CMSIS workspace not 4-byte aligned"); - ctx.buf = cmsis_workspace_ptr_; - ctx.size = get_cmsis_workspace_size(); - return ctx; - } - - bool is_kernel_sum_updated() const { - return header_->updated; - } - - void compute_kernel_sums_if_needed() { - if (!header_->updated) { - arm_vector_sum_s8( - reinterpret_cast(cmsis_workspace_ptr_), - in_features_, - out_features_, - get_weight_data(), - get_weight_zp_data()[0], - 0, - get_bias_data()); - header_->updated = true; - ET_LOG( - Info, - "Computed kernel sums. [required_bytes : %d]", - header_->required_size); - } - } - - const int8_t* get_weight_data() const { - return reinterpret_cast(base_ptr_ + weight_data_offset_); - } - - const int32_t* get_weight_zp_data() const { - return reinterpret_cast(base_ptr_ + weight_zp_data_offset_); - } - - const int32_t* get_bias_data() const { - return bias_data_offset_ == 0 - ? nullptr - : reinterpret_cast(base_ptr_ + bias_data_offset_); - } - - bool is_per_channel_quant() const { - return is_per_channel_; - } - int32_t get_in_features() const { - return in_features_; - } - int32_t get_out_features() const { - return out_features_; - } - - private: - static constexpr size_t KERNEL_SUM_HEADER_SIZE = 8; - - // Header for kernel sum computation state only - struct KernelSumHeader { - bool updated = false; - int32_t required_size = 0; - }; - static_assert( - sizeof(KernelSumHeader) == KERNEL_SUM_HEADER_SIZE, - "KernelSumHeader must be exactly 8 bytes"); - - int8_t* scratch_ptr_; - size_t total_size_; - uint8_t* base_ptr_; - - // Context members - const int32_t in_features_; - const int32_t out_features_; - const bool is_per_channel_; - const uint32_t weight_data_offset_; - const uint32_t weight_zp_data_offset_; - const uint32_t bias_data_offset_; - - KernelSumHeader* header_; - int8_t* cmsis_workspace_ptr_; - - uint32_t calculate_offset(const void* ptr) const { - if (ptr == nullptr) - return 0; - - const uint8_t* ptr_bytes = reinterpret_cast(ptr); - ET_CHECK_MSG(ptr_bytes >= base_ptr_, "Pointer is before base address"); - - const std::ptrdiff_t offset = ptr_bytes - base_ptr_; - ET_CHECK_MSG( - offset >= 0 && offset <= UINT32_MAX, "Offset out of valid range"); - return static_cast(offset); - } - - size_t get_cmsis_workspace_size() const { - return total_size_ - KERNEL_SUM_HEADER_SIZE; - } - - void validate_size(const cmsis_nn_dims& filter_dims) const { - header_->required_size = - arm_fully_connected_s8_get_buffer_size(&filter_dims); - - ET_CHECK_MSG( - get_cmsis_workspace_size() >= - static_cast(header_->required_size), - "Scratch buffer size %zu insufficient for required size %d", - get_cmsis_workspace_size(), - header_->required_size); - } -}; - -} // namespace native -} // namespace cortex_m diff --git a/backends/cortex_m/ops/cortex_m_ops_common.h b/backends/cortex_m/ops/cortex_m_ops_common.h index c7e6cc8a389..10fad9b6a0b 100644 --- a/backends/cortex_m/ops/cortex_m_ops_common.h +++ b/backends/cortex_m/ops/cortex_m_ops_common.h @@ -22,6 +22,7 @@ using Tensor = torch::executor::Tensor; using ScalarType = executorch::aten::ScalarType; using Scalar = torch::executor::Scalar; using Error = executorch::runtime::Error; +using IntArrayRef = executorch::aten::ArrayRef; // From arm_nn_math_types.h #define ARM_NN_Q31_MAX ((int32_t)(0x7FFFFFFFL)) diff --git a/backends/cortex_m/ops/op_quantized_linear.cpp b/backends/cortex_m/ops/op_quantized_linear.cpp index d1ccb6d0d45..ea01a8f772b 100644 --- a/backends/cortex_m/ops/op_quantized_linear.cpp +++ b/backends/cortex_m/ops/op_quantized_linear.cpp @@ -1,12 +1,12 @@ /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. + * Copyright 2025 Arm Limited and/or its affiliates. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ -#include "cmsis_scratch_buffer_context.h" #include "cortex_m_ops_common.h" extern "C" { @@ -20,151 +20,90 @@ using KernelRuntimeContext = torch::executor::KernelRuntimeContext; Tensor& quantized_linear_out( KernelRuntimeContext& context, const Tensor& input, - const Scalar& input_zero_point, - const Scalar& input_multiplier, - const Scalar& input_shift, const Tensor& weights, - const Tensor& weight_zero_point, - const Tensor& weight_multiplier, - const Tensor& weight_shift, const torch::executor::optional& bias, - const Tensor& bias_multiplier, - const Tensor& bias_shift, - const Tensor& scratch_buffer, - const Scalar& output_zero_point, - const Scalar& in_features, - const Scalar& out_features, + const torch::executor::optional& kernel_sum, + const Scalar& input_offset, + const Scalar& filter_offset, + const Scalar& output_offset, + const IntArrayRef requantize_multipliers, + const IntArrayRef requantize_shifts, + const Scalar& activation_max, + const Scalar& activation_min, Tensor& out) { ET_LOG(Info, "quantized_linear_out: called"); - validate_cmsis_nn_tensor_requirements(input, weights, out); - - ET_CHECK_MSG( - scratch_buffer.scalar_type() == ScalarType::Char, - "Scratch buffer must be int8"); - - const int32_t batch_size = input.size(0); - const int32_t in_feat = static_cast(in_features.to()); - const int32_t out_feat = static_cast(out_features.to()); - const int32_t input_zp = static_cast(input_zero_point.to()); - const int32_t output_zp = - static_cast(output_zero_point.to()); - const bool is_per_channel = (weight_zero_point.numel() > 1); const int8_t* input_data = input.const_data_ptr(); const int8_t* weight_data = weights.const_data_ptr(); const int32_t* bias_data = bias.has_value() ? bias.value().const_data_ptr() : nullptr; + int32_t* kernel_sum_data = + kernel_sum.has_value() ? kernel_sum.value().data_ptr() : nullptr; int8_t* output_data = out.mutable_data_ptr(); - const int32_t* weight_zp_data = weight_zero_point.const_data_ptr(); - const int32_t* weight_mult_data = weight_multiplier.const_data_ptr(); - const int32_t* weight_shift_data = weight_shift.const_data_ptr(); - - if (!validate_per_channel_quant_params( - weight_mult_data, weight_shift_data, out_feat)) { - context.fail(Error::InvalidArgument); - return out; - } - - // Initialize scratch buffer context (validates early) - CMSISScratchBufferContext scratch_ctx( - const_cast(scratch_buffer), weights, weight_zero_point, bias); - scratch_ctx.compute_kernel_sums_if_needed(); - cmsis_nn_context ctx = scratch_ctx.get_cmsis_ctx(); + cmsis_nn_context ctx; + ctx.size = 2; + ctx.buf = kernel_sum_data; // Setup CMSIS-NN parameters cmsis_nn_fc_params fc_params; - fc_params.input_offset = -input_zp; - fc_params.output_offset = output_zp; - fc_params.activation.min = std::numeric_limits::min(); - fc_params.activation.max = std::numeric_limits::max(); - - cmsis_nn_dims input_dims = {1, 1, 1, in_feat}; + fc_params.input_offset = static_cast(input_offset.to()); + fc_params.filter_offset = static_cast(filter_offset.to()); + fc_params.output_offset = static_cast(output_offset.to()); + fc_params.activation.min = static_cast(activation_min.to()); + fc_params.activation.max = static_cast(activation_max.to()); + + cmsis_nn_per_tensor_quant_params per_tensor_quant_params; + per_tensor_quant_params.multiplier = + static_cast(requantize_multipliers.at(0)); + per_tensor_quant_params.shift = static_cast(requantize_shifts.at(0)); + + auto in_feat = input.size(input.dim() - 1); + auto out_feat = out.size(out.dim() - 1); + auto batches = 1; + for (size_t i = 0; i < input.dim() - 1; i++) { + batches *= input.size(i); + } + ET_LOG( + Info, + "in features: %d, out_features: %d, batches: %d, kernel_sum_size: %d", + in_feat, + out_feat, + batches, + kernel_sum.has_value() ? kernel_sum.value().numel() : 0); + ET_LOG( + Info, + "kernel_sum[0]: %d, kernel_sum[1]: %d", + kernel_sum_data != nullptr ? kernel_sum_data[0] : -1, + kernel_sum_data != nullptr ? kernel_sum_data[1] : -1); + cmsis_nn_dims input_dims = {batches, 1, 1, in_feat}; cmsis_nn_dims filter_dims = {in_feat, 1, 1, out_feat}; cmsis_nn_dims bias_dims = {1, 1, 1, out_feat}; - cmsis_nn_dims output_dims = {1, 1, 1, out_feat}; - - arm_cmsis_nn_status status; - for (int32_t b = 0; b < batch_size; b++) { - const int8_t* batch_input = input_data + b * in_feat; - int8_t* batch_output = output_data + b * out_feat; - - ET_CHECK_MSG( - batch_input != nullptr && weight_data != nullptr, - "Null input pointers"); - ET_CHECK_MSG(in_feat > 0 && out_feat > 0, "Invalid dimensions"); - - if (is_per_channel) { - cmsis_nn_per_channel_quant_params per_channel_quant_params; - per_channel_quant_params.multiplier = - const_cast(weight_mult_data); - per_channel_quant_params.shift = const_cast(weight_shift_data); - - status = arm_fully_connected_per_channel_s8( - &ctx, - &fc_params, - &per_channel_quant_params, - &input_dims, - batch_input, - &filter_dims, - weight_data, - &bias_dims, - bias_data, - &output_dims, - batch_output); - } else { - fc_params.filter_offset = -weight_zp_data[0]; - cmsis_nn_per_tensor_quant_params per_tensor_quant_params; - per_tensor_quant_params.multiplier = weight_mult_data[0]; - per_tensor_quant_params.shift = weight_shift_data[0]; - - status = arm_fully_connected_s8( - &ctx, - &fc_params, - &per_tensor_quant_params, - &input_dims, - batch_input, - &filter_dims, - weight_data, - &bias_dims, - bias_data, - &output_dims, - batch_output); - } - - if (status != ARM_CMSIS_NN_SUCCESS) { - ET_LOG( - Error, - "quantized_linear_out: CMSIS-NN failed with status [%d]", - status); - context.fail(Error::Internal); - return out; - } + cmsis_nn_dims output_dims = {batches, 1, 1, out_feat}; + + arm_cmsis_nn_status status = arm_fully_connected_s8( + &ctx, + &fc_params, + &per_tensor_quant_params, + &input_dims, + input_data, + &filter_dims, + weight_data, + &bias_dims, + bias_data, + &output_dims, + output_data); + + if (status != ARM_CMSIS_NN_SUCCESS) { + ET_LOG( + Error, + "quantized_linear_out: CMSIS-NN failed with status [%d]", + status); + context.fail(Error::Internal); + return out; } - return out; -} -// Functional variant (stub, not used at runtime) -Tensor quantized_linear( - KernelRuntimeContext& context, - const Tensor& input, - const Scalar& input_zero_point, - const Scalar& input_multiplier, - const Scalar& input_shift, - const Tensor& weights, - const Tensor& weight_zero_point, - const Tensor& weight_multiplier, - const Tensor& weight_shift, - const torch::executor::optional& bias, - const Tensor& bias_multiplier, - const Tensor& bias_shift, - const Tensor& scratch_buffer, - const Scalar& output_zero_point, - const Scalar& in_features, - const Scalar& out_features) { - ET_LOG(Info, "quantized_linear: called"); - assert(false); - return const_cast(input); + return out; } } // namespace native diff --git a/backends/cortex_m/ops/operators.py b/backends/cortex_m/ops/operators.py index 286f938ccc9..b8abfb9bde4 100644 --- a/backends/cortex_m/ops/operators.py +++ b/backends/cortex_m/ops/operators.py @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from math import prod + import torch from executorch.backends.cortex_m.passes.passes_utils import ( requantize_cmsis, @@ -170,210 +172,110 @@ def quantized_add_impl( # QUANTIZED LINEAR OPERATION DEFINITION # =================================================================== - -def _check_per_tensor_or_per_channel(param: torch.Tensor, out_channels: int, name: str): - assert param.numel() in [ - 1, - out_channels, - ], f"{name} must be per-tensor (1) or per-channel ({out_channels}), got {param.numel()}" - - lib.define( "quantized_linear.out(" - "Tensor input, Scalar input_zero_point, Scalar input_multiplier, Scalar input_shift, " + "Tensor input, " "Tensor weights, " - "Tensor weight_zero_point, Tensor weight_multiplier, Tensor weight_shift, " - "Tensor? bias, Tensor bias_multiplier, Tensor bias_shift, " - "Tensor scratch_buffer, Scalar output_zero_point, Scalar in_features, Scalar out_features, " - "*, Tensor(a!) out) -> Tensor(a!)" + "Tensor? bias, " + "Tensor? kernel_sum, " + "Scalar input_offset, " + "Scalar filter_offset, " + "Scalar output_offset, " + "int[] requantize_multipliers, " + "int[] requantize_shifts, " + "Scalar activation_max, " + "Scalar activation_min, " + "*, Tensor(a!) out" + ") -> Tensor(a!)" ) # Define functional variant (non-out version) lib.define( "quantized_linear(" - "Tensor input, Scalar input_zero_point, Scalar input_multiplier, Scalar input_shift, " + "Tensor input, " "Tensor weights, " - "Tensor weight_zero_point, Tensor weight_multiplier, Tensor weight_shift, " - "Tensor? bias, Tensor bias_multiplier, Tensor bias_shift, " - "Tensor scratch_buffer, Scalar output_zero_point, Scalar in_features, Scalar out_features" + "Tensor? bias, " + "Tensor? kernel_sum, " + "Scalar input_offset, " + "Scalar filter_offset, " + "Scalar output_offset, " + "int[] requantize_multipliers, " + "int[] requantize_shifts, " + "Scalar activation_max, " + "Scalar activation_min" ") -> Tensor" ) -# Fake meta function for shape inference (out variant) -@register_fake("cortex_m::quantized_linear.out") -def quantized_linear_out_meta( - input: torch.Tensor, - input_zero_point: int, - input_multiplier: int, - input_shift: int, - weights: torch.Tensor, - weight_zero_point: torch.Tensor, - weight_multiplier: torch.Tensor, - weight_shift: torch.Tensor, - bias: torch.Tensor, - bias_multiplier: torch.Tensor, - bias_shift: torch.Tensor, - scratch_buffer: torch.Tensor, - output_zero_point: int, - in_features: int, - out_features: int, - out: torch.Tensor, -) -> torch.Tensor: - # Validate dimensions - batch_size = input.shape[0] - out_channels = weights.shape[0] - - # Validate weight quantization parameters dimensions - _check_per_tensor_or_per_channel( - weight_zero_point, out_channels, "weight_zero_point" - ) - _check_per_tensor_or_per_channel( - weight_multiplier, out_channels, "weight_multiplier" - ) - _check_per_tensor_or_per_channel(weight_shift, out_channels, "weight_shift") - - # Validate output shape - expected_shape = (batch_size, out_channels) - assert ( - out.shape == expected_shape - ), f"Output shape {out.shape} must be {expected_shape}" - - return out - - # Fake meta function for shape inference (functional variant) @register_fake("cortex_m::quantized_linear") def quantized_linear_meta( - input: torch.Tensor, - input_zero_point: int, - input_multiplier: int, - input_shift: int, - weights: torch.Tensor, - weight_zero_point: torch.Tensor, - weight_multiplier: torch.Tensor, - weight_shift: torch.Tensor, - bias: torch.Tensor, - bias_multiplier: torch.Tensor, - bias_shift: torch.Tensor, - scratch_buffer: torch.Tensor, - output_zero_point: int, - in_features: int, - out_features: int, -) -> torch.Tensor: - # Validate dimensions (same as out variant) - batch_size = input.shape[0] - out_channels = weights.shape[0] - - # Validate weight quantization parameters dimensions - _check_per_tensor_or_per_channel( - weight_zero_point, out_channels, "weight_zero_point" - ) - _check_per_tensor_or_per_channel( - weight_multiplier, out_channels, "weight_multiplier" - ) - _check_per_tensor_or_per_channel(weight_shift, out_channels, "weight_shift") - - # Calculate output shape for functional variant - output_shape = (batch_size, out_channels) - return torch.empty(output_shape, dtype=input.dtype, device=input.device) - - -@impl(lib, "quantized_linear.out", "CompositeExplicitAutograd") -def quantized_linear_out_impl( - input: torch.Tensor, - input_zero_point: int, - input_multiplier: int, - input_shift: int, - weights: torch.Tensor, - weight_zero_point: torch.Tensor, - weight_multiplier: torch.Tensor, - weight_shift: torch.Tensor, - bias: torch.Tensor, - bias_multiplier: torch.Tensor, - bias_shift: torch.Tensor, - scratch_buffer: torch.Tensor, - output_zero_point: int, - in_features: int, - out_features: int, - *, - out: torch.Tensor, + input, + weights, + bias, + kernel_sum, + input_offset, + filter_offset, + output_offset, + requantize_multipliers, + requantize_shifts, + activation_max, + activation_min, ) -> torch.Tensor: - """ - Fallback implementation for meta/testing - Note: This won't be called at runtime, only during compilation - """ - # Per-channel dequantization - input_scale = input_multiplier * (2.0 ** (-input_shift)) - input_fp = (input.float() - input_zero_point) * input_scale - if weight_zero_point.numel() == 1: - # Per-tensor - weight_scale = weight_multiplier.item() * (2.0 ** (-weight_shift.item())) - weights_fp = (weights.float() - weight_zero_point.item()) * weight_scale - else: - # Per-channel - weight_scales = weight_multiplier.float() * (2.0 ** (-weight_shift.float())) - weights_fp = ( - weights.float() - weight_zero_point.float().unsqueeze(1) - ) * weight_scales.unsqueeze(1) - bias_fp = None - if bias is not None: - bias_scales = bias_multiplier.float() * (2.0 ** (-bias_shift.float())) - bias_fp = bias.float() * bias_scales - - result_fp = torch.nn.functional.linear(input_fp, weights_fp, bias_fp) - else: - result_fp = torch.nn.functional.linear(input_fp, weights_fp) - result_quantized = torch.clamp( - torch.round(result_fp + output_zero_point), -128, 127 - ).to(torch.int8) - out.copy_(result_quantized) - return out + shape = (*input.shape[:-1], weights.shape[0]) + return torch.empty(shape, dtype=input.dtype, device=input.device) # Functional variant implementation @impl(lib, "quantized_linear", "CompositeExplicitAutograd") def quantized_linear_impl( input: torch.Tensor, - input_zero_point: int, - input_multiplier: int, - input_shift: int, weights: torch.Tensor, - weight_zero_point: torch.Tensor, - weight_multiplier: torch.Tensor, - weight_shift: torch.Tensor, bias: torch.Tensor, - bias_multiplier: torch.Tensor, - bias_shift: torch.Tensor, - scratch_buffer: torch.Tensor, - output_zero_point: int, - in_features: int, - out_features: int, + kernel_sum: torch.Tensor, + input_offset: int, + filter_offset: int, + output_offset: int, + requantize_multipliers: torch.Tensor, + requantize_shifts: torch.Tensor, + activation_max: int, + activation_min: int, ) -> torch.Tensor: """ Functional variant - creates output tensor and calls out variant """ - # Create output tensor - batch_size = input.shape[0] - output = torch.empty( - (batch_size, out_features), dtype=torch.int8, device=input.device - ) - return quantized_linear_out_impl( - input, - input_zero_point, - input_multiplier, - input_shift, - weights, - weight_zero_point, - weight_multiplier, - weight_shift, - bias, - bias_multiplier, - bias_shift, - scratch_buffer, - output_zero_point, - in_features, - out_features, - out=output, + + # Leaving both implementations for debugging purposes. + compute_using_kernel_sum = True + + if compute_using_kernel_sum: + weights_int32 = weights.to(torch.int32) + + input_int32 = input.to(torch.int32) + new_shape = (prod(input.shape[:-1]), input.shape[-1]) + input_reshaped = input_int32.reshape(new_shape) + + lhs_sum = torch.sum(input_reshaped, dim=-1, keepdim=True) * filter_offset + output = torch.mm(input_reshaped, weights_int32.T) + lhs_sum + kernel_sum + output_shape = (*input.shape[:-1], output.shape[-1]) + output_reshaped = output.reshape(output_shape) + else: + weights_int32 = weights.to(torch.int32) + filter_offset + + input_int32 = input.to(torch.int32) + input_offset + new_shape = (prod(input.shape[:-1]), input.shape[-1]) + input_reshaped = input_int32.reshape(new_shape) + + output = torch.mm(input_reshaped, weights_int32.T) + if bias is not None: + output = output + bias + output_shape = (*input.shape[:-1], output.shape[-1]) + output_reshaped = output.reshape(output_shape) + + output = requantize_cmsis( + output_reshaped, requantize_multipliers[0], requantize_shifts[0] ) + output += output_offset + output = torch.clamp(output, activation_min, activation_max).to(torch.int8) + return output diff --git a/backends/cortex_m/ops/operators.yaml b/backends/cortex_m/ops/operators.yaml index 81ebeafc778..98d8df8797e 100644 --- a/backends/cortex_m/ops/operators.yaml +++ b/backends/cortex_m/ops/operators.yaml @@ -23,14 +23,8 @@ - arg_meta: null kernel_name: cortex_m::quantized_add_out -- func: cortex_m::quantized_linear(Tensor input, Scalar input_zero_point, Scalar input_multiplier, Scalar input_shift, Tensor weights, Tensor weight_zero_point, Tensor weight_multiplier, Tensor weight_shift, Tensor? bias, Tensor bias_multiplier, Tensor bias_shift, Tensor scratch_buffer, Scalar output_zero_point, Scalar in_features, Scalar out_features) -> Tensor +- func: cortex_m::quantized_linear.out(Tensor input, Tensor weights, Tensor? bias, Tensor? kernel_sum, Scalar input_offset, Scalar filter_offset, Scalar output_offset, int[] requantize_multipliers, int[] requantize_shifts, Scalar activation_max, Scalar activation_min, *, Tensor(a!) out) -> Tensor(a!) variants: function kernels: - arg_meta: null - kernel_name: cortex_m::quantized_linear - -- func: cortex_m::quantized_linear.out(Tensor input, Scalar input_zero_point, Scalar input_multiplier, Scalar input_shift, Tensor weights, Tensor weight_zero_point, Tensor weight_multiplier, Tensor weight_shift, Tensor? bias, Tensor bias_multiplier, Tensor bias_shift, Tensor scratch_buffer, Scalar output_zero_point, Scalar in_features, Scalar out_features, *, Tensor(a!) out) -> Tensor(a!) - variants: function - kernels: - - arg_meta: null - kernel_name: cortex_m::quantized_linear_out + kernel_name: cortex_m::quantized_linear_out \ No newline at end of file diff --git a/backends/cortex_m/passes/cortex_m_pass_manager.py b/backends/cortex_m/passes/cortex_m_pass_manager.py index 02429cc68e0..10fb358c70e 100644 --- a/backends/cortex_m/passes/cortex_m_pass_manager.py +++ b/backends/cortex_m/passes/cortex_m_pass_manager.py @@ -4,7 +4,11 @@ # LICENSE file in the root directory of this source tree. -from executorch.backends.arm._passes import ScalarsToAttributePass +from executorch.backends.arm._passes import ( + DecorateFp32toInt32CastingPass, + FoldAndAnnotateQParamsPass, + ScalarsToAttributePass, +) from executorch.backends.cortex_m.passes import ( QuantizedLinearFusionPass, QuantizedOpFusionPass, @@ -20,10 +24,12 @@ class CortexMPassManager(XNNPACKPassManager): pass_list: list[ExportPass] = [ + FoldAndAnnotateQParamsPass, ReplaceScalarWithTensorArgPass, ReplaceQuantNodesPass, QuantizedOpFusionPass, QuantizedLinearFusionPass, + DecorateFp32toInt32CastingPass, ] pass_list_transform_for_annotation: list[ExportPass] = [ diff --git a/backends/cortex_m/passes/quantized_linear_fusion_pass.py b/backends/cortex_m/passes/quantized_linear_fusion_pass.py index 11a49beb2f4..0493d5faf90 100644 --- a/backends/cortex_m/passes/quantized_linear_fusion_pass.py +++ b/backends/cortex_m/passes/quantized_linear_fusion_pass.py @@ -5,642 +5,147 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import logging -from typing import Optional import executorch.backends.cortex_m.ops.operators # noqa + import torch import torch.fx +from executorch.backends.cortex_m.passes.passes_utils import quantize_multiplier_aot -from executorch.backends.cortex_m.passes.passes_utils import ( - cleanup_nodes, - is_dequant_node, - quantize_multiplier_aot, - transfer_metadata, +from executorch.backends.transforms.utils import ( + create_constant_placeholder, + get_param_tensor, ) -from executorch.backends.transforms.utils import create_mutable_buffer, get_param_tensor - from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass -from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops -from torch.fx import Node +from torch.export.graph_signature import InputKind from torch.fx.passes.infra.pass_manager import PassResult -logger = logging.getLogger("quantized_linear_fusion_pass") -logger.setLevel(logging.INFO) - class QuantizedLinearFusionPass(XNNPACKPass): """ Cortex-M backend pass that fuses quantized linear-like patterns. Fuses: dequantize -> [linear/addmm/fc_ops] -> quantize Into: cortex_m.quantized_linear.default with direct parameters. - """ - - SUPPORTED_OPS_MAPPING = { - exir_ops.edge.aten.addmm.default: exir_ops.edge.cortex_m.quantized_linear.default, - exir_ops.edge.aten.mm.default: exir_ops.edge.cortex_m.quantized_linear.default, - } - - requires_exported_program = True - - def __init__(self, exported_program: ExportedProgram): - super().__init__(exported_program) - self.nodes_to_erase = [] - - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - logger.info("Starting QuantizedLinearFusionPass") - assert id(self._exported_program.graph_module.graph) == id( - graph_module.graph - ), "QuantizedLinearFusionPass requires same graph instance" - - try: - fusion_count = self._fuse_quantized_linear_patterns(graph_module) - if fusion_count > 0: - graph_module.graph.eliminate_dead_code() - graph_module.graph.lint() - graph_module.recompile() - logger.info(f"Linear fusion completed: {fusion_count} patterns fused") - return PassResult(graph_module, fusion_count > 0) - except Exception as e: - logger.error(f"Error in QuantizedLinearFusionPass: {e}") - raise e - - def _extract_linear_pattern(self, quantize_node: Node): - if not quantize_node.args: - return None - fc_node = quantize_node.args[0] - if not ( - fc_node.op == "call_function" - and fc_node.target in self.SUPPORTED_OPS_MAPPING - ): - return None - - op_name = str(fc_node.target).split(".")[-1] - - if "addmm" in str(fc_node.target): - input_dq_node = fc_node.args[1] - else: - input_dq_node = fc_node.args[0] - if not is_dequant_node(input_dq_node): - logger.info("input_dq_node is not a dequant node") - return None - weight_dq_node, bias_dq_node = self._extract_weight_bias_from_fc_op(fc_node) - if not weight_dq_node: - logger.info("No weight, bias dequantize node found") - return None - return ( - quantize_node, - fc_node, - input_dq_node, - weight_dq_node, - bias_dq_node, - op_name, - ) - - def _extract_weight_bias_from_fc_op(self, fc_node: Node): - """Generic extraction for FC-like operations.""" - - if "addmm" in str(fc_node.target): - if len(fc_node.args) >= 3: - bias_arg = fc_node.args[0] - weight_arg = fc_node.args[2] - weight_dq_node = self._trace_to_dequantize(weight_arg) - logger.info( - f"weight_arg: {weight_arg}, traced weight_dq_node: {weight_dq_node}" - ) - - if weight_dq_node is None: - logger.info("No weight dequantize node found ") - - # For bias, try to trace to dequantize but allow None (no-bias case) - bias_dq_node = self._trace_to_dequantize(bias_arg) - if bias_dq_node is None: - logger.info("No bias dequantize node found - likely no-bias linear") - return weight_dq_node, bias_dq_node - elif any(op in str(fc_node.target) for op in ["linear", "mm"]): - if len(fc_node.args) >= 2: - weight_arg = fc_node.args[1] - bias_arg = fc_node.args[2] if len(fc_node.args) > 2 else None - weight_dq_node = self._trace_to_dequantize(weight_arg) - bias_dq_node = self._trace_to_dequantize(bias_arg) if bias_arg else None - return weight_dq_node, bias_dq_node - return None, None - - def _extract_input_quantization_parameters( - self, input_dq_node: Node - ) -> Optional[dict]: - """Extract input quantization parameters from dequantize node.""" - try: - # Find the quantize operation that produces the int8 tensor - input_quantize_node = None - if hasattr(input_dq_node, "args") and input_dq_node.args: - quantize_candidate = input_dq_node.args[0] - if getattr( - quantize_candidate, "op", None - ) == "call_function" and "quantize" in str( - getattr(quantize_candidate, "target", "") - ): - input_quantize_node = quantize_candidate - - if not input_quantize_node: - logger.error("Could not find quantize node for input!") - return None - - # Extract input quantization parameters - input_scale = self._extract_param_value(input_dq_node.args[1]) - input_zero_point = int(self._extract_param_value(input_dq_node.args[2])) - input_multiplier, input_shift = quantize_multiplier_aot(input_scale) - - return { - "input_scale": input_scale, - "input_zero_point": input_zero_point, - "input_multiplier": input_multiplier, - "input_shift": input_shift, - "input_tensor": input_quantize_node, - } - except Exception as e: - logger.error(f"Failed to extract input quantization parameters: {e}") - return None - - def _extract_output_quantization_parameters( - self, quantize_node: Node - ) -> Optional[dict]: - """Extract output quantization parameters from quantize node.""" - try: - output_scale = self._extract_param_value(quantize_node.args[1]) - output_zero_point = int(self._extract_param_value(quantize_node.args[2])) - return { - "output_scale": output_scale, - "output_zero_point": output_zero_point, - } - except Exception as e: - logger.error(f"Failed to extract output quantization parameters: {e}") - return None + Note that the optimzed implementation makes use of the following rewrite: - def _create_constant_parameter_buffer( - self, graph, quantize_node: Node, data: torch.Tensor, name: str - ): - """Create a parameter buffer""" - buffer_name = f"{name}_{id(quantize_node)}" + Let + - yi be the output activations (y1, ... yn) + - xj be the input activations (x1, ... xm) + - wij be the weights (w11, ... wnm) + - a be the input offset + - b be the weight offset + - ci be the bias - setattr(graph.owning_module, buffer_name, data) - - # Create a get_attr node - with graph.inserting_before(quantize_node): - buffer_node = graph.create_node( - op="get_attr", target=buffer_name, name=buffer_name - ) - - # Set metadata - buffer_node.meta["val"] = data - - return buffer_node - - def _extract_weight_parameters(self, weight_dq_node: Node) -> Optional[dict]: - try: - weight_tensor = weight_dq_node.args[0] - weight_scale = weight_dq_node.args[1] - weight_zero_point = ( - weight_dq_node.args[2] if len(weight_dq_node.args) > 2 else None - ) - - weight_scale_data = self._extract_param_value(weight_scale) - weight_zp_data = ( - self._extract_param_value(weight_zero_point) - if weight_zero_point - else None - ) - - # Get actual tensor data to determine output features - weight_tensor_data = get_param_tensor(self._exported_program, weight_tensor) - out_features = weight_tensor_data.shape[0] - - # Handle both per-tensor and per-channel - if ( - isinstance(weight_scale_data, torch.Tensor) - and weight_scale_data.numel() > 1 - ): - # Per-channel: ensure we have the right number of elements - assert ( - weight_scale_data.numel() == out_features - ), f"Scale size {weight_scale_data.numel()} != out_features {out_features}" - - multipliers = [] - shifts = [] - for scale in weight_scale_data: - mult, shift = quantize_multiplier_aot(scale.item()) - multipliers.append(mult) - shifts.append(shift) - - weight_multiplier = torch.tensor(multipliers, dtype=torch.int32) - weight_shift = torch.tensor(shifts, dtype=torch.int32) - weight_zp_tensor = ( - weight_zp_data.int() - if weight_zp_data is not None - else torch.zeros(out_features, dtype=torch.int32) - ) - else: - # Per-tensor: create tensors with correct size for output features - scale_val = ( - weight_scale_data.item() - if isinstance(weight_scale_data, torch.Tensor) - else weight_scale_data - ) - mult, shift = quantize_multiplier_aot(scale_val) - - # Create tensors sized for out_features (not single element) - weight_multiplier = torch.full((out_features,), mult, dtype=torch.int32) - weight_shift = torch.full((out_features,), shift, dtype=torch.int32) - weight_zp_tensor = torch.full( - (out_features,), - weight_zp_data if weight_zp_data else 0, - dtype=torch.int32, - ) - - # Validate multipliers - for i, mult in enumerate(weight_multiplier): - if mult < (1 << 30) or mult > ((1 << 31) - 1): - logger.error( - f"Invalid multiplier[{i}]: {mult}, scale was: {weight_scale_data}" - ) - return None - - return { - "weight_tensor": weight_tensor, - "weight_zero_point_data": weight_zp_tensor, - "weight_multiplier_data": weight_multiplier, - "weight_shift_data": weight_shift, - } - except Exception as e: - logger.error(f"Failed to extract weight parameters: {e}") - return None - - def _extract_bias_parameters(self, bias_dq_node: Optional[Node]) -> Optional[dict]: - """ - Extract bias parameters for quantized linear fusion. - Handles both dequantized bias nodes and constant bias tensors. - Returns a dict with bias_tensor, bias_multiplier, and bias_shift. - """ - if not bias_dq_node: - # No bias present - return None - try: - # Case 1: Bias is a dequantize node - if hasattr(bias_dq_node, "op") and is_dequant_node(bias_dq_node): - bias_tensor = bias_dq_node.args[0] - bias_scale = bias_dq_node.args[1] + Then the linear operation can be written as: + yi = sum_j((xj + a) * (wij + b)) + ci + = sum_j(xj*wij + xj*b + a*wij + a*b) + ci + = sum_j(xj*wij) + sum_j(xj)*b + (a * sum_j(wij + b) + ci) + = sum_j(xj*wij) + sum_j(xj)*b + kernel_sum - bias_scale_data = self._extract_param_value(bias_scale) - - if ( - isinstance(bias_scale_data, torch.Tensor) - and bias_scale_data.numel() > 1 - ): - # Per-channel bias - bias_multipliers = [] - bias_shifts = [] - for scale_val in bias_scale_data.tolist(): - mult, shift = quantize_multiplier_aot(scale_val) - bias_multipliers.append(mult) - bias_shifts.append(shift) - return { - "bias_tensor": bias_tensor, - "bias_multiplier": bias_multipliers, - "bias_shift": bias_shifts, - } - else: - # Per-tensor bias - bias_scale_val = ( - bias_scale_data.item() - if isinstance(bias_scale_data, torch.Tensor) - else bias_scale_data - ) - bias_multiplier, bias_shift = quantize_multiplier_aot( - bias_scale_val - ) - return { - "bias_tensor": bias_tensor, - "bias_multiplier": bias_multiplier, - "bias_shift": bias_shift, - } - else: - # Case 2: Bias is a constant tensor (not dequantized) - # This can happen if bias is not quantized in the model - bias_tensor = bias_dq_node - # Use default multiplier/shift for unquantized bias - bias_multiplier = 1 - bias_shift = 0 - return { - "bias_tensor": bias_tensor, - "bias_multiplier": bias_multiplier, - "bias_shift": bias_shift, - } - except Exception as e: - logger.error(f"Failed to extract bias parameters: {e}") - return None - - def _prepare_bias_tensors( - self, bias_params: Optional[dict], out_features: int - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Prepare bias multiplier and shift tensors for kernel call. - Returns (bias_multiplier_tensor, bias_shift_tensor) both sized [out_features]. - """ - if bias_params: - bias_multiplier = bias_params["bias_multiplier"] - bias_shift = bias_params["bias_shift"] - - # Convert to tensors of the right size - if isinstance(bias_multiplier, int): - bias_multiplier_tensor = torch.full( - [out_features], bias_multiplier, dtype=torch.int32 - ) - elif isinstance(bias_multiplier, list): - assert ( - len(bias_multiplier) == out_features - ), f"Bias multiplier size {len(bias_multiplier)} != out_features {out_features}" - bias_multiplier_tensor = torch.tensor( - bias_multiplier, dtype=torch.int32 - ) - elif isinstance(bias_multiplier, torch.Tensor): - assert ( - bias_multiplier.numel() == out_features - ), f"Bias multiplier size {bias_multiplier.numel()} != out_features {out_features}" - bias_multiplier_tensor = bias_multiplier - else: - raise TypeError( - f"Unsupported bias_multiplier type: {type(bias_multiplier)}" - ) - - if isinstance(bias_shift, int): - bias_shift_tensor = torch.full( - [out_features], bias_shift, dtype=torch.int32 - ) - elif isinstance(bias_shift, list): - assert ( - len(bias_shift) == out_features - ), f"Bias shift size {len(bias_shift)} != out_features {out_features}" - bias_shift_tensor = torch.tensor(bias_shift, dtype=torch.int32) - elif isinstance(bias_shift, torch.Tensor): - assert ( - bias_shift.numel() == out_features - ), f"Bias shift size {bias_shift.numel()} != out_features {out_features}" - bias_shift_tensor = bias_shift - else: - raise TypeError(f"Unsupported bias_shift type: {type(bias_shift)}") - - return bias_multiplier_tensor, bias_shift_tensor - else: - # No bias: return zero tensors of correct shape - return ( - torch.zeros([out_features], dtype=torch.int32), - torch.zeros([out_features], dtype=torch.int32), - ) + where kernel_sum is precomputed aot. + """ - def _extract_param_value(self, node_or_value): - """ - Extract a scalar value from a Node or a direct float/int. + def _compute_kernel_sum(self, weights, bias, input_offset, weight_offset): """ - if isinstance(node_or_value, (float, int)): - return node_or_value - # If it's a tensor, get its scalar value if possible - if isinstance(node_or_value, torch.Tensor): - return node_or_value.item() if node_or_value.numel() == 1 else node_or_value - # If it's a Node, use get_param_tensor - if hasattr(node_or_value, "op"): - tensor = get_param_tensor(self._exported_program, node_or_value) - return tensor.item() if tensor.numel() == 1 else tensor - raise TypeError(f"Unsupported parameter type: {type(node_or_value)}") - - def _calculate_cmsis_scratch_size(self, weight_tensor) -> int: - """Calculate CMSIS-NN scratch buffer size for quantized linear operations. + Computes the precomputed kernel sum term (bias optional) + a * sum_j(wij + b) + ci - Source: CMSIS-NN arm_fully_connected_s8_get_buffer_size() returns filter_dims->w * sizeof(int32_t). - This buffer stores pre-computed kernel sums (weight row sums) - one int32_t per output feature. - Same buffer size applies to both per-tensor and per-channel quantization paths since both use - identical kernel sum optimization in the underlying matrix multiplication. + as defined above, for i = (1, ..., n) where j indexes the input activations. """ - try: - print(f"weight_tensor type: {type(weight_tensor)}, value: {weight_tensor}") - weight_shape = get_param_tensor(self._exported_program, weight_tensor).shape - out_features = weight_shape[0] # filter_dims->w in CMSIS terms - - # CMSIS-NN implementation expects the following size - cmsis_buffer_size = out_features * 4 # sizeof(int32_t) - return cmsis_buffer_size - except Exception as e: - logger.error(f"Failed to calculate CMSIS scratch size: {e}") - return 2048 # Fallback - - def _create_scratch_buffer(self, graph, quantize_node: Node, weight_tensor): - cmsis_scratch = self._calculate_cmsis_scratch_size(weight_tensor) - - kernel_sum_header = 8 # sizeof(KernelSumHeader) - total_size = kernel_sum_header + cmsis_scratch - - logger.info( - f"Kernel sum header: {kernel_sum_header}, CMSIS buffer: {cmsis_scratch}, total: {total_size}" - ) - - return create_mutable_buffer( - self._exported_program, - name=f"b_cmsis_linear_scratch_{id(quantize_node)}", - data=torch.zeros((total_size,), dtype=torch.int8), - ) - - def _create_fused_node( - self, - graph, - quantize_node: Node, - quant_params: dict, - weight_params: dict, - bias_params: Optional[dict], - quantized_target, - ) -> Node: - """Generic fused node creation for any FC-like operation.""" - # Extract all parameters - input_tensor = quant_params["input_tensor"] - input_zp = quant_params["input_zero_point"] - input_multiplier = quant_params["input_multiplier"] - input_shift = quant_params["input_shift"] - weight_tensor = weight_params["weight_tensor"] - - weight_zp_node = self._create_constant_parameter_buffer( - graph, quantize_node, weight_params["weight_zero_point_data"], "weight_zp" - ) - weight_mult_node = self._create_constant_parameter_buffer( - graph, quantize_node, weight_params["weight_multiplier_data"], "weight_mult" - ) - weight_shift_node = self._create_constant_parameter_buffer( - graph, quantize_node, weight_params["weight_shift_data"], "weight_shift" + weights_transposed = weights.T + weights_int32 = weights_transposed.to(torch.int32) + offset_weights = weights_int32 + weight_offset + kernel_sum = torch.sum(offset_weights, dim=0, keepdim=True, dtype=torch.int32) + kernel_sum_offset = kernel_sum * input_offset + + if bias is not None: + kernel_sum_offset += bias + + return kernel_sum_offset + + def _get_linear_replacement(self, args, meta, node) -> int: + input_scale = meta["input_qparams"][0].scale + input_zp = meta["input_qparams"][0].zp + weight_scale = meta["input_qparams"][1].scale + weight_zp = meta["input_qparams"][1].zp + output_scale = meta["output_qparams"][0].scale + output_zp = meta["output_qparams"][0].zp + output_min = meta["output_qparams"][0].qmin + output_max = meta["output_qparams"][0].qmax + + quantized_multiplier, quantized_shift = quantize_multiplier_aot( + (input_scale * weight_scale) / output_scale ) - # Get dimensions - weight_shape = get_param_tensor(self._exported_program, weight_tensor).shape - assert ( - len(weight_shape) == 2 - ), f"Weight tensor must be 2D, got shape {weight_shape}" - in_features = weight_shape[1] - out_features = weight_shape[0] - # Handle bias - bias_tensor = bias_params["bias_tensor"] if bias_params else None - bias_multiplier, bias_shift = self._prepare_bias_tensors( - bias_params, out_features + # TODO: Add support for configuring the backend to support other extensions. + # Kernel sum is only used in the CMSIS-NN implementation for the MVE extension, + # so this should be optional. + weights = args[1] + weights_tensor = get_param_tensor(self.exported_program, weights) + bias_tensor = ( + get_param_tensor(self.exported_program, args[2]) if len(args) > 2 else None ) - output_zp = quant_params["output_zero_point"] - - scratch_buffer = self._create_scratch_buffer( - graph, quantize_node, weight_tensor + kernel_sum_tensor = self._compute_kernel_sum( + weights_tensor, bias_tensor, -input_zp, -weight_zp ) - - with graph.inserting_after(quantize_node): - fused = graph.create_node( - "call_function", - target=quantized_target, - args=( - input_tensor, - input_zp, - input_multiplier, - input_shift, - weight_tensor, - weight_zp_node, - weight_mult_node, - weight_shift_node, - bias_tensor, - bias_multiplier, - bias_shift, - scratch_buffer, - output_zp, - in_features, - out_features, - ), - kwargs={}, + with node.graph.inserting_after(weights): + kernel_sum = create_constant_placeholder( + self.exported_program, + node.graph, + node.name + "_kernel_sum", + InputKind.PARAMETER, + kernel_sum_tensor, ) - transfer_metadata(fused, quantize_node, "QuantizedLinearFusionPass") - return fused - - def _mark_for_cleanup(self, nodes): - for node in nodes: - if node is not None: - self.nodes_to_erase.append(node) - - def _cleanup_nodes(self, graph): - cleanup_nodes(self.nodes_to_erase, graph) - self.nodes_to_erase.clear() - - def _extract_linear_pattern_with_validation(self, quantize_node: Node): - pattern_info = self._extract_linear_pattern(quantize_node) - if not pattern_info: - return None - # Optionally add more validation here if needed - return pattern_info + args = ( + args[0], + weights, + None, + kernel_sum, + -input_zp, + -weight_zp, + output_zp, + [quantized_multiplier], + [quantized_shift], + output_max, + output_min, + ) - def _trace_to_dequantize(self, node: Optional[Node], max_depth=3) -> Optional[Node]: - """Trace through transformations to find dequantize node.""" - current_node = node - depth = 0 - while current_node and depth < max_depth: - if is_dequant_node(current_node): - return current_node - if current_node.op == "call_function" and current_node.target in { - exir_ops.edge.aten.permute_copy.default, - exir_ops.edge.aten.view_copy.default, - }: - if current_node.args: - current_node = current_node.args[0] - depth += 1 - continue - break - return None + return args - def _fuse_quantized_linear_patterns( - self, graph_module: torch.fx.GraphModule - ) -> int: - fusion_count = 0 - graph = graph_module.graph - for node in list(graph.nodes): - if not ( - node.op == "call_function" and "quantize_per_tensor" in str(node.target) - ): + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + modified = False + for node in graph_module.graph.nodes: + if node.op != "call_function": continue - pattern_info = self._extract_linear_pattern_with_validation(node) - if not pattern_info: + if node.target != exir_ops.edge.aten.linear.default: continue - - ( - quantize_node, - fc_node, - input_dq_node, - weight_dq_node, - bias_dq_node, - op_name, - ) = pattern_info - - # Get quantized target for this FC operation - quantized_target = self.SUPPORTED_OPS_MAPPING.get(fc_node.target) - if not quantized_target: - logger.warning(f"No quantized target found for {fc_node.target}") + if ( + node.meta.get("input_qparams", {}) == {} + or node.meta.get("output_qparams", {}) == {} + ): continue - logger.info(f"✅ Found complete cortex_m Q/DQ + {op_name} pattern!") - - try: - input_params = self._extract_input_quantization_parameters( - input_dq_node - ) - if not input_params: - logger.error( - "Quantization parameter extraction failed for node: %s", node - ) - return None - output_params = self._extract_output_quantization_parameters( - quantize_node + args = self._get_linear_replacement(node.args, node.meta, node) + with graph_module.graph.inserting_before(node): + cortex_m_linear = graph_module.graph.create_node( + "call_function", + target=exir_ops.edge.cortex_m.quantized_linear.default, + args=args, + kwargs={}, ) - if not output_params: - logger.error( - "Output quantization parameter extraction failed for node: %s", - node, - ) - return None - quant_params = {**input_params, **output_params} - logger.info(f"Quantization parameters: {quant_params}") - weight_params = self._extract_weight_parameters(weight_dq_node) - if not weight_params: - continue - bias_params = self._extract_bias_parameters(bias_dq_node) - if bias_dq_node and not bias_params: - continue - fused_node = self._create_fused_node( - graph, - quantize_node, - quant_params, - weight_params, - bias_params, - quantized_target, - ) - logger.info(f"Created fused {op_name} node: {fused_node}") + node.replace_all_uses_with(cortex_m_linear) + graph_module.graph.erase_node(node) - quantize_node.replace_all_uses_with(fused_node) - self._mark_for_cleanup( - [ - quantize_node, - fc_node, - input_dq_node, - weight_dq_node, - bias_dq_node, - ] - ) - fusion_count += 1 - logger.info(f"✅ Successfully fused {op_name} operation {fusion_count}") - except Exception as e: - logger.error( - f"Failed to fuse {op_name} pattern for {fc_node.name}: {e}" - ) - continue - self._cleanup_nodes(graph) - return fusion_count + modified = True + + if modified: + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module, modified) diff --git a/backends/cortex_m/passes/quantized_op_fusion_pass.py b/backends/cortex_m/passes/quantized_op_fusion_pass.py index 888155dcfd0..202d91f27ce 100644 --- a/backends/cortex_m/passes/quantized_op_fusion_pass.py +++ b/backends/cortex_m/passes/quantized_op_fusion_pass.py @@ -5,23 +5,17 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import logging -from typing import Set - -import executorch.backends.cortex_m.ops.operators # noqa -import torch +from typing import Dict from executorch.backends.cortex_m.passes.passes_utils import ( - extract_scalar_value, quantize_multiplier_aot, SHIFT_INT8, ) -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass -from torch.fx.passes.infra.pass_manager import PassResult -logger = logging.getLogger("quant_op_fusion_pass") -logger.setLevel(logging.INFO) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload +from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue +from torch.fx.node import Argument class QuantizedOpFusionPass(ExportPass): @@ -35,234 +29,58 @@ class QuantizedOpFusionPass(ExportPass): Supports multiple binary operations with backward compatibility for add. """ - # Generic operation mapping - SUPPORTED_OPS_MAPPING = { - exir_ops.edge.aten.add.Tensor: exir_ops.edge.cortex_m.quantized_add.default, - # Future binary ops to be added here: - } + def _get_add_replacement(self, args, meta) -> int: - def __init__(self): - super().__init__() + # Extract values + scale1 = meta["input_qparams"][0].scale + zero_point1 = meta["input_qparams"][0].zp + scale2 = meta["input_qparams"][1].scale + zero_point2 = meta["input_qparams"][1].zp + output_scale = meta["output_qparams"][0].scale + output_zero_point = meta["output_qparams"][0].zp - def _get_dequant_targets(self) -> Set: - """Support both decomposed and cortex_m dequant targets for flexible pass ordering.""" - return { - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, - exir_ops.edge.cortex_m.dequantize_per_tensor.default, - } + # AoT COMPUTATION: Calculate multipliers and shifts + max_scale_2x = 2 * max(scale1, scale2) - def _get_quant_targets(self) -> Set: - """Support both decomposed and cortex_m quant targets for flexible pass ordering.""" - return { - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, - exir_ops.edge.cortex_m.quantize_per_tensor.default, - } - - def _is_supported_binary_op(self, node: torch.fx.Node) -> bool: - """Check if node is a supported binary operation.""" - is_supported = ( - node.op == "call_function" and node.target in self.SUPPORTED_OPS_MAPPING + input1_mult, input1_shift = quantize_multiplier_aot(scale1 / max_scale_2x) + input2_mult, input2_shift = quantize_multiplier_aot(scale2 / max_scale_2x) + output_mult, output_shift = quantize_multiplier_aot( + max_scale_2x / (output_scale * (1 << SHIFT_INT8)) ) - if not is_supported: - return False - - shape1 = node.args[0].meta["val"].shape - shape2 = node.args[1].meta["val"].shape - is_broadcast = shape1 != shape2 - return not is_broadcast - def _is_dequant_node(self, node: torch.fx.Node) -> bool: - """Check if node is a dequantize operation.""" - return ( - hasattr(node, "op") - and node.op == "call_function" - and node.target in self._get_dequant_targets() + args = ( + args[0], + zero_point1, + input1_mult, + input1_shift, + args[1], + zero_point2, + input2_mult, + input2_shift, + output_zero_point, + output_mult, + output_shift, ) - def _is_quant_node(self, node: torch.fx.Node) -> bool: - """Check if node is a quantize operation.""" - return ( - hasattr(node, "op") - and node.op == "call_function" - and node.target in self._get_quant_targets() - ) + return exir_ops.edge.cortex_m.quantized_add.default, args - def _transfer_metadata( + def call_operator( self, - new_node: torch.fx.Node, - source_node: torch.fx.Node, - pass_name: str = "QuantizedOpFusionPass", - ) -> None: - """Metadata transfer with proper provenance tracking.""" - if hasattr(source_node, "meta") and source_node.meta: - new_node.meta = source_node.meta.copy() - - if "from_node" in new_node.meta: - from_node_list = new_node.meta.get("from_node", []).copy() - from_node_list.append( - {"source": source_node.name, "pass": pass_name, "op": "fuse"} - ) - new_node.meta["from_node"] = from_node_list - - # Copy essential fields - for field in ["tensor_meta", "stack_trace"]: - if field in source_node.meta: - new_node.meta[field] = source_node.meta[field] - - def _normalize_to_cortex_m_targets(self, graph_module: torch.fx.GraphModule) -> int: - """Convert decomposed targets to cortex_m equivalents for consistent handling.""" - target_mapping = { - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: exir_ops.edge.cortex_m.dequantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: exir_ops.edge.cortex_m.quantize_per_tensor.default, - } - - normalization_count = 0 - for node in list(graph_module.graph.nodes): - if node.op == "call_function" and node.target in target_mapping: - logger.info(f"Normalizing {node.target} to cortex_m equivalent") - node.target = target_mapping[node.target] - normalization_count += 1 - - return normalization_count - - def _fuse_quantized_binary_patterns( - self, graph_module: torch.fx.GraphModule - ) -> int: - """Generic fusion for quantized binary operation patterns.""" - fusion_count = 0 - nodes_to_erase = [] - - for node in list(graph_module.graph.nodes): - if not self._is_quant_node(node): - continue - - quantize_node = node - if not quantize_node.args: - continue - - binary_op_node = quantize_node.args[0] - if not self._is_supported_binary_op(binary_op_node): - continue - - if len(binary_op_node.args) < 2: - continue - - dequant_node1, dequant_node2 = binary_op_node.args[:2] - if not ( - self._is_dequant_node(dequant_node1) - and self._is_dequant_node(dequant_node2) - ): - continue - - # Get the target quantized operation - quantized_target = self.SUPPORTED_OPS_MAPPING[binary_op_node.target] - # Extract op name (e.g., 'Tensor' -> 'add') - op_name = str(binary_op_node.target).split(".")[-1] - logger.info(f"✅ Found complete cortex_m Q/DQ + {op_name} pattern!") - - try: - # Extract values - int8_tensor1, scale1, zero_point1 = dequant_node1.args[:3] - int8_tensor2, scale2, zero_point2 = dequant_node2.args[:3] - output_scale, output_zero_point = quantize_node.args[1:3] - - # Convert to Python floats - scale1_val = extract_scalar_value(scale1) - scale2_val = extract_scalar_value(scale2) - output_scale_val = extract_scalar_value(output_scale) - zp1_val = int(extract_scalar_value(zero_point1)) - zp2_val = int(extract_scalar_value(zero_point2)) - output_zp_val = int(extract_scalar_value(output_zero_point)) - - max_scale_2x = 2 * max(scale1_val, scale2_val) - # AoT COMPUTATION: Calculate multipliers and shifts - - input1_mult, input1_shift = quantize_multiplier_aot( - scale1_val / max_scale_2x - ) - input2_mult, input2_shift = quantize_multiplier_aot( - scale2_val / max_scale_2x - ) - output_mult, output_shift = quantize_multiplier_aot( - max_scale_2x / (output_scale_val * (1 << SHIFT_INT8)) - ) - - logger.info("AoT computed parameters:") - logger.info(f" Input1: mult={input1_mult}, shift={input1_shift}") - logger.info(f" Input2: mult={input2_mult}, shift={input2_shift}") - logger.info(f" Output: mult={output_mult}, shift={output_shift}") - - with graph_module.graph.inserting_after(quantize_node): - fused = graph_module.graph.create_node( - "call_function", - target=quantized_target, - args=( - int8_tensor1, - zp1_val, - input1_mult, - input1_shift, - int8_tensor2, - zp2_val, - input2_mult, - input2_shift, - output_zp_val, - output_mult, - output_shift, - ), - kwargs={}, - ) - - # metadata transfer - self._transfer_metadata(fused, quantize_node) - - logger.info(f"✅ Created fused quantized_{op_name} node: {fused}") - - # Replace all uses - quantize_node.replace_all_uses_with(fused) - binary_op_node.replace_all_uses_with(fused) - dequant_node1.replace_all_uses_with(fused) - dequant_node2.replace_all_uses_with(fused) - - nodes_to_erase.extend( - [quantize_node, binary_op_node, dequant_node1, dequant_node2] - ) - fusion_count += 1 - logger.info(f"Pattern fused, total so far: {fusion_count}") - - except Exception as e: - logger.info(f"❌ Error during AoT computation: {e}") - logger.info(" Skipping fusion for this pattern") - continue - - for old_node in reversed(nodes_to_erase): - if old_node in graph_module.graph.nodes and len(old_node.users) == 0: - logger.info(f"🗑️ Erasing node: {old_node}") - graph_module.graph.erase_node(old_node) - - return fusion_count - - def call(self, graph_module: torch.fx.GraphModule): - logger.info("QuantizedOpFusionPass.call() started") - - # Normalize targets for flexible pass ordering - normalization_count = self._normalize_to_cortex_m_targets(graph_module) - - # Generic fusion for supported binary operations - fusion_count = self._fuse_quantized_binary_patterns(graph_module) - - total_changes = normalization_count + fusion_count - logger.info(f"Total changes: {total_changes}") - - if total_changes > 0: - graph_module.graph.eliminate_dead_code() - graph_module.graph.lint() - graph_module.recompile() - - logger.debug("=== AFTER FUSION: All nodes in the graph ===") - for i, node in enumerate(graph_module.graph.nodes): - logger.debug(f"Node {i}: op={node.op}, target={node.target}") - if "quantized_" in str(node.target) and "add" in str(node.target): - logger.debug(" ⭐ FOUND QUANTIZED BINARY OP NODE! ⭐") - logger.debug("=== END DEBUG ===") - - return PassResult(graph_module, total_changes > 0) + op: EdgeOpOverload, + args: tuple[Argument, ...], + kwargs: Dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if ( + meta.data.get("input_qparams", {}) == {} + or meta.data.get("output_qparams", {}) == {} + ): + return super().call_operator(op, args, {}, meta) + + match op: + case exir_ops.edge.aten.add.Tensor: + op, args = self._get_add_replacement(args, meta) + case _: + pass + + return super().call_operator(op, args, {}, meta) diff --git a/backends/cortex_m/test/ops/test_linear.py b/backends/cortex_m/test/ops/test_linear.py index 4ab5ca99f15..e81daa7e83e 100644 --- a/backends/cortex_m/test/ops/test_linear.py +++ b/backends/cortex_m/test/ops/test_linear.py @@ -4,8 +4,8 @@ # LICENSE file in the root directory of this source tree. -import pytest import torch +from executorch.backends.arm.test.common import parametrize from executorch.backends.cortex_m.test.tester import ( CortexMTester, McuTestCase, @@ -13,12 +13,9 @@ ) -class CortexMMm(torch.nn.Module): - def forward(self, x, y): - return torch.mm(x, y) - +class CortexMLinear(torch.nn.Module): ops_before_transforms = { - "executorch_exir_dialects_edge__ops_aten_mm_default": 1, + "executorch_exir_dialects_edge__ops_aten_linear_default": 1, "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, } @@ -29,32 +26,45 @@ def forward(self, x, y): "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, } + def __init__(self, *args, **kwargs): + super().__init__() + self.linear = torch.nn.Linear(*args, bias=False) + self.linear.weight.data.fill_(1.0) + + def forward(self, x): + return self.linear(x) -class CortexMBmm(torch.nn.Module): - def forward(self, x, y): - return torch.bmm(x, y) +class CortexMLinearX3(torch.nn.Module): ops_before_transforms = { - "executorch_exir_dialects_edge__ops_aten_bmm_default": 1, - "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, - "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + "executorch_exir_dialects_edge__ops_aten_linear_default": 3, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 4, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 7, } ops_after_transforms = { - "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 3, "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, } + def __init__(self, *args, **kwargs): + super().__init__() + self.linear = torch.nn.Linear(*args, bias=False) + self.linear.weight.data.fill_(1.0) + + def forward(self, x): + x = self.linear(x) + x = self.linear(x) + x = self.linear(x) + return x -class CortexMAddmm(torch.nn.Module): - def forward(self, x, y, z, alpha=None, beta=None): - return torch.addmm(beta, x, alpha, y, z) +class CortexMLinearBias(torch.nn.Module): ops_before_transforms = { - "executorch_exir_dialects_edge__ops_aten_addmm_default": 1, + "executorch_exir_dialects_edge__ops_aten_linear_default": 1, "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, - "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 4, } ops_after_transforms = { @@ -63,90 +73,23 @@ def forward(self, x, y, z, alpha=None, beta=None): "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, } - -class CortexMAt(CortexMMm): - def forward(self, x, y): - return x @ y - - -class CortexMMatmul(CortexMMm): - def forward(self, x, y): - return torch.matmul(x, y) - - -class CortexMLinear(CortexMMatmul): - def __init__(self, *args, **kwargs): - super().__init__() - self.linear = torch.nn.Linear(*args, bias=False) - - def forward(self, x): - return self.linear(x) - - -class CortexMLinearBias(CortexMAddmm): def __init__(self, *args, **kwargs): super().__init__() self.linear = torch.nn.Linear(*args, bias=True) self.relu = torch.nn.ReLU() def forward(self, x): - return self.relu(self.linear(x)) + return self.linear(x) test_cases = { - "mm": McuTestCase( - model=CortexMMm(), - example_inputs=( - ramp_tensor(0, 10, (1, 16)), - ramp_tensor(0, 10, (16, 16)), - ), - ), - "bmm": McuTestCase( - model=CortexMBmm(), - example_inputs=( - ramp_tensor(0, 10, (1, 16, 16)), - ramp_tensor(0, 10, (1, 16, 16)), - ), - ), - "addmm": McuTestCase( - model=CortexMAddmm(), - example_inputs=( - ramp_tensor(0, 10, (1, 16)), - ramp_tensor(0, 10, (16, 16)), - ramp_tensor(0, 10, (16, 16)), - 2, - 4, - ), - ), - "addmm_scalars": McuTestCase( - model=CortexMAddmm(), - example_inputs=( - ramp_tensor(0, 10, (1, 16)), - ramp_tensor(0, 10, (16, 16)), - ramp_tensor(0, 10, (16, 16)), - ), - ), - "@-operator": McuTestCase( - model=CortexMAt(), - example_inputs=( - ramp_tensor(0, 10, (1, 16)), - ramp_tensor(0, 10, (16, 16)), - ), - ), - "matmul": McuTestCase( - model=CortexMMatmul(), - example_inputs=( - ramp_tensor(0, 10, (1, 16)), - ramp_tensor(0, 10, (16, 16)), - ), - ), "linear_rank1": McuTestCase( - model=CortexMLinear(2, 3), - example_inputs=(ramp_tensor(-1, 1, (2,)),), + model=CortexMLinear(1, 2), + example_inputs=(torch.Tensor([1]),), ), "linear_rank2_pos": McuTestCase( - model=CortexMLinear(8, 3), - example_inputs=(ramp_tensor(0, 10, (2, 8)),), + model=CortexMLinear(1, 2), + example_inputs=(ramp_tensor(-1, 1, (1, 1)),), ), "linear_rank3_neg": McuTestCase( model=CortexMLinear(5, 3), @@ -164,22 +107,24 @@ def forward(self, x): model=CortexMLinearBias(61, 37), example_inputs=(ramp_tensor(0, 10, (8, 61)),), ), + "linear_x3": McuTestCase( + model=CortexMLinearX3(4, 4), + example_inputs=(ramp_tensor(0, 10, (2, 4)),), + ), } -@pytest.mark.skip( - reason="Skipping until the quantized_linear_fusion_pass is updated to work with non decomposed linear ops." -) +@parametrize("test_case", test_cases) def test_dialect_linear(test_case): tester = CortexMTester(test_case.model, test_case.example_inputs) tester.test_dialect( - test_case.model.ops_before_transforms, test_case.model.ops_after_transforms + test_case.model.ops_before_transforms, + test_case.model.ops_after_transforms, + qtol=1, ) -@pytest.mark.skip( - reason="Skipping until the quantized_linear_fusion_pass is updated to work with non decomposed linear ops." -) +@parametrize("test_case", test_cases) def test_implementation_linear(test_case): tester = CortexMTester(test_case.model, test_case.example_inputs) - tester.test_implementation() + tester.test_implementation(qtol=1) From b52728925d34015159af705968e7cf3e32d63ace Mon Sep 17 00:00:00 2001 From: Adrian Lundell Date: Wed, 5 Nov 2025 10:07:02 +0100 Subject: [PATCH 2/3] Address PR comments Signed-off-by: Adrian Lundell --- .../ops/cmsis_scratch_buffer_context.h | 187 ++++++++++++++++++ backends/cortex_m/ops/op_quantized_linear.cpp | 2 +- .../passes/quantized_linear_fusion_pass.py | 2 +- .../passes/quantized_op_fusion_pass.py | 2 +- 4 files changed, 190 insertions(+), 3 deletions(-) create mode 100644 backends/cortex_m/ops/cmsis_scratch_buffer_context.h diff --git a/backends/cortex_m/ops/cmsis_scratch_buffer_context.h b/backends/cortex_m/ops/cmsis_scratch_buffer_context.h new file mode 100644 index 00000000000..4b9fdaebdf7 --- /dev/null +++ b/backends/cortex_m/ops/cmsis_scratch_buffer_context.h @@ -0,0 +1,187 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include "cortex_m_ops_common.h" +extern "C" { +#include "arm_nnfunctions.h" +} + +namespace cortex_m { +namespace native { + +// During AOT phase, quantized_linear_fusion_pass allocates total buffer +// and passes in as 'Tensor'. (Total buffer = 8-byte header + x bytes) +// ┌─────────────────┬─────────────────────────────────────┐ +// │ KernelSum Header│ CMSIS Workspace │ +// │ (8 bytes) │ (x bytes) │ +// └─────────────────┴─────────────────────────────────────┘ +// │ │ +// │ └─> Passed to CMSIS API +// │ +// └─> State for kernel sum + +// C++ Runtime: +// ┌─────────────────┬─────────────────────────────────────┐ +// │ KernelSum Header│ CMSIS Workspace │ +// │ (8 bytes) │ (x bytes) │ +// └─────────────────┴─────────────────────────────────────┘ +// ^ ^ +// │ │ +// scratch_ptr cmsis_workspace_ptr +// │ │ +// ▼ ▼ +// arm_vector_sum_s8() writes kernel sums (with bias if avail): +// [sum₀+bias₀][sum₁+bias₁][sum₂+bias₂]...[sum_{n-1}+bias_{n-1}] +// (n * 4-byte int32_t values = x bytes) +// +// - n = out_features (number of output features) +// - x = n * 4 bytes (total CMSIS buffer size) +// - Total buffer = 8 + x bytes + +class CMSISScratchBufferContext final { + public: + CMSISScratchBufferContext( + Tensor& scratch_buffer, + const Tensor& weights, + const Tensor& weight_zero_point, + const torch::executor::optional& bias) + : scratch_ptr_(scratch_buffer.mutable_data_ptr()), + total_size_(scratch_buffer.size(0)), + base_ptr_(reinterpret_cast(scratch_ptr_)), + in_features_(weights.size(1)), + out_features_(weights.size(0)), + is_per_channel_(weight_zero_point.numel() > 1), + weight_data_offset_(calculate_offset(weights.const_data_ptr())), + weight_zp_data_offset_( + calculate_offset(weight_zero_point.const_data_ptr())), + bias_data_offset_( + bias.has_value() + ? calculate_offset(bias.value().const_data_ptr()) + : 0), + header_(reinterpret_cast(scratch_ptr_)), + cmsis_workspace_ptr_(scratch_ptr_ + KERNEL_SUM_HEADER_SIZE) { + cmsis_nn_dims filter_dims = {in_features_, 1, 1, out_features_}; + validate_size(filter_dims); + } + + cmsis_nn_context get_cmsis_ctx() const { + cmsis_nn_context ctx; + ET_CHECK_MSG( + reinterpret_cast(cmsis_workspace_ptr_) % 4 == 0, + "CMSIS workspace not 4-byte aligned"); + ctx.buf = cmsis_workspace_ptr_; + ctx.size = get_cmsis_workspace_size(); + return ctx; + } + + bool is_kernel_sum_updated() const { + return header_->updated; + } + + void compute_kernel_sums_if_needed() { + if (!header_->updated) { + arm_vector_sum_s8( + reinterpret_cast(cmsis_workspace_ptr_), + in_features_, + out_features_, + get_weight_data(), + get_weight_zp_data()[0], + 0, + get_bias_data()); + header_->updated = true; + ET_LOG( + Info, + "Computed kernel sums. [required_bytes : %d]", + header_->required_size); + } + } + + const int8_t* get_weight_data() const { + return reinterpret_cast(base_ptr_ + weight_data_offset_); + } + + const int32_t* get_weight_zp_data() const { + return reinterpret_cast(base_ptr_ + weight_zp_data_offset_); + } + + const int32_t* get_bias_data() const { + return bias_data_offset_ == 0 + ? nullptr + : reinterpret_cast(base_ptr_ + bias_data_offset_); + } + + bool is_per_channel_quant() const { + return is_per_channel_; + } + int32_t get_in_features() const { + return in_features_; + } + int32_t get_out_features() const { + return out_features_; + } + + private: + static constexpr size_t KERNEL_SUM_HEADER_SIZE = 8; + + // Header for kernel sum computation state only + struct KernelSumHeader { + bool updated = false; + int32_t required_size = 0; + }; + static_assert( + sizeof(KernelSumHeader) == KERNEL_SUM_HEADER_SIZE, + "KernelSumHeader must be exactly 8 bytes"); + + int8_t* scratch_ptr_; + size_t total_size_; + uint8_t* base_ptr_; + + // Context members + const int32_t in_features_; + const int32_t out_features_; + const bool is_per_channel_; + const uint32_t weight_data_offset_; + const uint32_t weight_zp_data_offset_; + const uint32_t bias_data_offset_; + + KernelSumHeader* header_; + int8_t* cmsis_workspace_ptr_; + + uint32_t calculate_offset(const void* ptr) const { + if (ptr == nullptr) + return 0; + + const uint8_t* ptr_bytes = reinterpret_cast(ptr); + ET_CHECK_MSG(ptr_bytes >= base_ptr_, "Pointer is before base address"); + + const std::ptrdiff_t offset = ptr_bytes - base_ptr_; + ET_CHECK_MSG( + offset >= 0 && offset <= UINT32_MAX, "Offset out of valid range"); + return static_cast(offset); + } + + size_t get_cmsis_workspace_size() const { + return total_size_ - KERNEL_SUM_HEADER_SIZE; + } + + void validate_size(const cmsis_nn_dims& filter_dims) const { + header_->required_size = + arm_fully_connected_s8_get_buffer_size(&filter_dims); + + ET_CHECK_MSG( + get_cmsis_workspace_size() >= + static_cast(header_->required_size), + "Scratch buffer size %zu insufficient for required size %d", + get_cmsis_workspace_size(), + header_->required_size); + } +}; + +} // namespace native +} // namespace cortex_m diff --git a/backends/cortex_m/ops/op_quantized_linear.cpp b/backends/cortex_m/ops/op_quantized_linear.cpp index ea01a8f772b..015fa805134 100644 --- a/backends/cortex_m/ops/op_quantized_linear.cpp +++ b/backends/cortex_m/ops/op_quantized_linear.cpp @@ -42,7 +42,7 @@ Tensor& quantized_linear_out( int8_t* output_data = out.mutable_data_ptr(); cmsis_nn_context ctx; - ctx.size = 2; + ctx.size = 0; // Not used in CMSIS-NN ctx.buf = kernel_sum_data; // Setup CMSIS-NN parameters diff --git a/backends/cortex_m/passes/quantized_linear_fusion_pass.py b/backends/cortex_m/passes/quantized_linear_fusion_pass.py index 0493d5faf90..f921f5ce621 100644 --- a/backends/cortex_m/passes/quantized_linear_fusion_pass.py +++ b/backends/cortex_m/passes/quantized_linear_fusion_pass.py @@ -66,7 +66,7 @@ def _compute_kernel_sum(self, weights, bias, input_offset, weight_offset): return kernel_sum_offset - def _get_linear_replacement(self, args, meta, node) -> int: + def _get_linear_replacement(self, args, meta, node): input_scale = meta["input_qparams"][0].scale input_zp = meta["input_qparams"][0].zp weight_scale = meta["input_qparams"][1].scale diff --git a/backends/cortex_m/passes/quantized_op_fusion_pass.py b/backends/cortex_m/passes/quantized_op_fusion_pass.py index 202d91f27ce..df35c8d626a 100644 --- a/backends/cortex_m/passes/quantized_op_fusion_pass.py +++ b/backends/cortex_m/passes/quantized_op_fusion_pass.py @@ -29,7 +29,7 @@ class QuantizedOpFusionPass(ExportPass): Supports multiple binary operations with backward compatibility for add. """ - def _get_add_replacement(self, args, meta) -> int: + def _get_add_replacement(self, args, meta): # Extract values scale1 = meta["input_qparams"][0].scale From dd7c05e78c04d700d6b9e5764f574976ad7ecd5a Mon Sep 17 00:00:00 2001 From: Adrian Lundell Date: Thu, 6 Nov 2025 10:12:02 +0100 Subject: [PATCH 3/3] Address merge issues Fix a merge issue causing the build to fail + update tests after merging of #15590 Signed-off-by: Adrian Lundell --- backends/cortex_m/ops/op_quantized_linear.cpp | 17 +++++------ backends/cortex_m/test/ops/test_add.py | 29 ++++++++++--------- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/backends/cortex_m/ops/op_quantized_linear.cpp b/backends/cortex_m/ops/op_quantized_linear.cpp index ee9dd4e8a28..015fa805134 100644 --- a/backends/cortex_m/ops/op_quantized_linear.cpp +++ b/backends/cortex_m/ops/op_quantized_linear.cpp @@ -21,16 +21,15 @@ Tensor& quantized_linear_out( KernelRuntimeContext& context, const Tensor& input, const Tensor& weights, - const Tensor& weight_zero_point, - const Tensor& weight_multiplier, - const Tensor& weight_shift, const torch::executor::optional& bias, - const Tensor& bias_multiplier, - const Tensor& bias_shift, - const Tensor& scratch_buffer, - const Scalar& output_zero_point, - const Scalar& in_features, - const Scalar& out_features, + const torch::executor::optional& kernel_sum, + const Scalar& input_offset, + const Scalar& filter_offset, + const Scalar& output_offset, + const IntArrayRef requantize_multipliers, + const IntArrayRef requantize_shifts, + const Scalar& activation_max, + const Scalar& activation_min, Tensor& out) { ET_LOG(Info, "quantized_linear_out: called"); diff --git a/backends/cortex_m/test/ops/test_add.py b/backends/cortex_m/test/ops/test_add.py index 4389b463076..458d5361347 100644 --- a/backends/cortex_m/test/ops/test_add.py +++ b/backends/cortex_m/test/ops/test_add.py @@ -59,17 +59,6 @@ class CortexMTensorAdd(Model): } -class CortexMTensorAddBroadcast(Model): - # TODO: Quantize and accelerate broadcasted adds - ops_before_transforms = { - "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, - } - - ops_after_transforms = { - "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, - } - - class CortexMAlphaAdd(ModelAlpha): ops_before_transforms = { "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, @@ -126,15 +115,15 @@ class CortexMAlphaAdd(ModelAlpha): (torch.rand(2, 2) * 10, torch.rand(2, 2)), ), "broadcast_1": McuTestCase( - CortexMTensorAddBroadcast(), + CortexMTensorAdd(), (torch.ones(1), torch.ones(2, 2, 2, 2)), ), "broadcast_2": McuTestCase( - CortexMTensorAddBroadcast(), + CortexMTensorAdd(), (torch.ones((2, 1, 1, 1)), torch.ones(1)), ), "broadcast_3": McuTestCase( - CortexMTensorAddBroadcast(), + CortexMTensorAdd(), ( ramp_tensor(-2, 2, (2, 1, 2, 1)), ramp_tensor(-5, 5, (1, 2, 1, 2)), @@ -183,6 +172,18 @@ def test_dialect_add(test_case): "'float' object has not attribute 'fake_mode' - scalar only ops not supported.", AttributeError, ), + "broadcast_1": ( + " assert failed (input1.sizes() == input2.sizes()): Input1 and Input2 must have the same sizes.", + RuntimeError, + ), + "broadcast_2": ( + " assert failed (input1.sizes() == input2.sizes()): Input1 and Input2 must have the same sizes.", + RuntimeError, + ), + "broadcast_3": ( + " assert failed (input1.sizes() == input2.sizes()): Input1 and Input2 must have the same sizes.", + RuntimeError, + ), "alpha": ( "Expecting kwargs for aten op IR to be empty - alpha arg not supported.", AssertionError,