@@ -100,6 +100,64 @@ inline bool is_tensor_contiguous(
100100
101101} // extern "C"
102102
103+ // Utility function to convert sizes pointer to vector
104+ inline std::vector<executorch::aten::SizesType> convert_sizes_to_vector (
105+ int64_t ndim,
106+ const int64_t * sizes_ptr) {
107+ std::vector<executorch::aten::SizesType> sizes (ndim);
108+ for (int i = 0 ; i < ndim; i++) {
109+ sizes[i] = static_cast <executorch::aten::SizesType>(sizes_ptr[i]);
110+ }
111+ return sizes;
112+ }
113+
114+ // Utility function to convert strides pointer to vector or calculate from sizes
115+ inline std::vector<executorch::aten::StridesType> convert_strides_to_vector (
116+ int64_t ndim,
117+ const int64_t * sizes_ptr,
118+ const int64_t * strides_ptr) {
119+ std::vector<executorch::aten::StridesType> strides (ndim);
120+
121+ if (strides_ptr != nullptr ) {
122+ // Use provided strides.
123+ for (int64_t i = 0 ; i < ndim; i++) {
124+ strides[i] = static_cast <executorch::aten::StridesType>(strides_ptr[i]);
125+ }
126+ } else {
127+ // Calculate strides from sizes.
128+ if (ndim > 0 ) {
129+ strides[ndim - 1 ] = static_cast <executorch::aten::StridesType>(
130+ 1 ); // Last dimension has stride 1
131+ for (int64_t i = ndim - 2 ; i >= 0 ; i--) {
132+ if (sizes_ptr[i + 1 ] == 0 ) {
133+ strides[i] = strides[i + 1 ]; // Copy stride when size is 0
134+ } else {
135+ strides[i] = static_cast <executorch::aten::StridesType>(
136+ static_cast <int64_t >(strides[i + 1 ]) * sizes_ptr[i + 1 ]);
137+ }
138+ }
139+ }
140+ }
141+ return strides;
142+ }
143+
144+ // Check if tensor is in contiguous memory format (NCHW for 4D tensors)
145+ // Contiguous format means strides decrease from left to right:
146+ // For NCHW: strides = [C*H*W, H*W, W, 1]
147+ inline bool is_contiguous_tensor (
148+ std::vector<executorch::aten::SizesType>& sizes,
149+ std::vector<executorch::aten::StridesType>& strides) {
150+ int64_t ndim = static_cast <int64_t >(strides.size ());
151+ int64_t expected_stride = 1 ;
152+ for (int64_t i = ndim - 1 ; i >= 0 ; i--) {
153+ if (strides[i] != expected_stride) {
154+ return false ;
155+ }
156+ expected_stride *= sizes[i];
157+ }
158+ return true ;
159+ }
160+
103161} // namespace aoti
104162} // namespace backends
105163} // namespace executorch
0 commit comments