Skip to content

Commit 6e0c9f6

Browse files
Add Metal backend type definitions and utilities (#15019)
Implement foundational types and utilities for Metal backend including: - AOTI type aliases (AOTITensorHandle, AOTIRuntimeError, AOTITorchError) - Device type handling functions - Tensor storage size queries - Tensor attribute utilities
1 parent 5f046eb commit 6e0c9f6

File tree

9 files changed

+262
-42
lines changed

9 files changed

+262
-42
lines changed

backends/aoti/utils.h

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,64 @@ inline bool is_tensor_contiguous(
100100

101101
} // extern "C"
102102

103+
// Utility function to convert sizes pointer to vector
104+
inline std::vector<executorch::aten::SizesType> convert_sizes_to_vector(
105+
int64_t ndim,
106+
const int64_t* sizes_ptr) {
107+
std::vector<executorch::aten::SizesType> sizes(ndim);
108+
for (int i = 0; i < ndim; i++) {
109+
sizes[i] = static_cast<executorch::aten::SizesType>(sizes_ptr[i]);
110+
}
111+
return sizes;
112+
}
113+
114+
// Utility function to convert strides pointer to vector or calculate from sizes
115+
inline std::vector<executorch::aten::StridesType> convert_strides_to_vector(
116+
int64_t ndim,
117+
const int64_t* sizes_ptr,
118+
const int64_t* strides_ptr) {
119+
std::vector<executorch::aten::StridesType> strides(ndim);
120+
121+
if (strides_ptr != nullptr) {
122+
// Use provided strides.
123+
for (int64_t i = 0; i < ndim; i++) {
124+
strides[i] = static_cast<executorch::aten::StridesType>(strides_ptr[i]);
125+
}
126+
} else {
127+
// Calculate strides from sizes.
128+
if (ndim > 0) {
129+
strides[ndim - 1] = static_cast<executorch::aten::StridesType>(
130+
1); // Last dimension has stride 1
131+
for (int64_t i = ndim - 2; i >= 0; i--) {
132+
if (sizes_ptr[i + 1] == 0) {
133+
strides[i] = strides[i + 1]; // Copy stride when size is 0
134+
} else {
135+
strides[i] = static_cast<executorch::aten::StridesType>(
136+
static_cast<int64_t>(strides[i + 1]) * sizes_ptr[i + 1]);
137+
}
138+
}
139+
}
140+
}
141+
return strides;
142+
}
143+
144+
// Check if tensor is in contiguous memory format (NCHW for 4D tensors)
145+
// Contiguous format means strides decrease from left to right:
146+
// For NCHW: strides = [C*H*W, H*W, W, 1]
147+
inline bool is_contiguous_tensor(
148+
std::vector<executorch::aten::SizesType>& sizes,
149+
std::vector<executorch::aten::StridesType>& strides) {
150+
int64_t ndim = static_cast<int64_t>(strides.size());
151+
int64_t expected_stride = 1;
152+
for (int64_t i = ndim - 1; i >= 0; i--) {
153+
if (strides[i] != expected_stride) {
154+
return false;
155+
}
156+
expected_stride *= sizes[i];
157+
}
158+
return true;
159+
}
160+
103161
} // namespace aoti
104162
} // namespace backends
105163
} // namespace executorch
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/apple/metal/runtime/shims/tensor_attribute.h>
10+
#include <executorch/backends/apple/metal/runtime/shims/utils.h>
11+
#include <iostream>
12+
13+
namespace executorch {
14+
namespace backends {
15+
namespace metal {
16+
17+
extern "C" {
18+
19+
// Metal-specific device type constant
20+
__attribute__((__visibility__("default"))) int32_t
21+
aoti_torch_device_type_mps() {
22+
return 13; // Consistent with c10/core/DeviceType.h
23+
}
24+
25+
// Override aoti_torch_get_device_type to return MPS device type
26+
AOTITorchError aoti_torch_get_device_type(
27+
AOTITensorHandle tensor,
28+
int32_t* ret_device_type) {
29+
*ret_device_type = aoti_torch_device_type_mps();
30+
return Error::Ok;
31+
}
32+
33+
} // extern "C"
34+
35+
} // namespace metal
36+
} // namespace backends
37+
} // namespace executorch
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/backends/aoti/common_shims.h>
12+
#include <executorch/backends/apple/metal/runtime/shims/types.h>
13+
14+
namespace executorch {
15+
namespace backends {
16+
namespace metal {
17+
18+
extern "C" {
19+
20+
// Metal-specific device type function
21+
int32_t aoti_torch_device_type_mps();
22+
23+
// Override aoti_torch_get_device_type to return MPS device type
24+
AOTITorchError aoti_torch_get_device_type(
25+
AOTITensorHandle tensor,
26+
int32_t* ret_device_type);
27+
28+
} // extern "C"
29+
30+
} // namespace metal
31+
} // namespace backends
32+
} // namespace executorch
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/extension/tensor/tensor.h>
12+
#include <executorch/runtime/core/error.h>
13+
#include <cstdint>
14+
15+
namespace executorch {
16+
namespace backends {
17+
namespace metal {
18+
19+
// Common using declarations for ExecutorTorch types
20+
using executorch::runtime::Error;
21+
using executorch::runtime::etensor::Tensor;
22+
23+
extern "C" {
24+
25+
// Common AOTI type aliases
26+
// Note: AOTITensorHandle is aliased to Tensor* for ExecutorTorch compatibility
27+
using AOTITensorHandle = Tensor*;
28+
using AOTIRuntimeError = Error;
29+
using AOTITorchError = Error;
30+
31+
} // extern "C"
32+
33+
} // namespace metal
34+
} // namespace backends
35+
} // namespace executorch
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/apple/metal/runtime/shims/utils.h>
10+
#include <executorch/runtime/platform/log.h>
11+
#include <cstdint>
12+
13+
namespace executorch {
14+
namespace backends {
15+
namespace metal {
16+
17+
extern "C" {
18+
19+
// Helper function to check if a dtype is supported in Metal backend
20+
bool is_dtype_supported_in_et_metal(int32_t dtype) {
21+
switch (dtype) {
22+
case static_cast<int32_t>(SupportedDTypes::INT64):
23+
case static_cast<int32_t>(SupportedDTypes::FLOAT32):
24+
case static_cast<int32_t>(SupportedDTypes::BFLOAT16):
25+
return true;
26+
default:
27+
return false;
28+
}
29+
}
30+
31+
// Metal-specific dtype validation utility function
32+
AOTITorchError validate_dtype(int32_t dtype) {
33+
if (is_dtype_supported_in_et_metal(dtype)) {
34+
return Error::Ok;
35+
}
36+
37+
ET_LOG(
38+
Error,
39+
"Unsupported dtype: %d. Supported dtypes: %d (int64), %d (float32), %d (bfloat16)",
40+
dtype,
41+
static_cast<int32_t>(SupportedDTypes::INT64),
42+
static_cast<int32_t>(SupportedDTypes::FLOAT32),
43+
static_cast<int32_t>(SupportedDTypes::BFLOAT16));
44+
return Error::InvalidArgument;
45+
}
46+
47+
} // extern "C"
48+
49+
} // namespace metal
50+
} // namespace backends
51+
} // namespace executorch
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/backends/aoti/utils.h>
12+
#include <executorch/backends/apple/metal/runtime/shims/types.h>
13+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
14+
#include <cstdint>
15+
16+
namespace executorch {
17+
namespace backends {
18+
namespace metal {
19+
20+
// Enum for supported data types in et-metal backend
21+
enum class SupportedDTypes : int32_t {
22+
// UINT8 = 0, // PyTorch's uint8 dtype code
23+
// INT8 = 1, // PyTorch's int8 dtype code
24+
// INT16 = 2, // PyTorch's int16 dtype code
25+
// INT32 = 3, // PyTorch's int32 dtype code
26+
INT64 = 4, // PyTorch's int64 dtype code
27+
// FLOAT16 = 5, // PyTorch's float16 dtype code
28+
FLOAT32 = 6, // PyTorch's float32 dtype code
29+
// FLOAT64 = 7, // PyTorch's float64 dtype code
30+
// BOOL = 11, // PyTorch's bool dtype code
31+
BFLOAT16 = 15 // PyTorch's bfloat16 dtype code
32+
};
33+
34+
extern "C" {
35+
36+
// Helper function to check if a dtype is supported in Metal backend
37+
bool is_dtype_supported_in_et_metal(int32_t dtype);
38+
39+
// Metal-specific dtype validation utility function
40+
AOTITorchError validate_dtype(int32_t dtype);
41+
42+
} // extern "C"
43+
44+
} // namespace metal
45+
} // namespace backends
46+
} // namespace executorch

backends/cuda/runtime/shims/memory.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ using executorch::backends::aoti::aoti_torch_get_device_index;
2727
using executorch::backends::aoti::aoti_torch_get_dtype;
2828
using executorch::backends::aoti::aoti_torch_get_sizes;
2929
using executorch::backends::aoti::aoti_torch_get_strides;
30+
using executorch::backends::aoti::convert_sizes_to_vector;
31+
using executorch::backends::aoti::convert_strides_to_vector;
3032
using executorch::backends::aoti::dtype_to_element_size;
3133
using executorch::backends::aoti::dtype_to_scalar_type;
3234
using executorch::backends::aoti::validate_storage_offset;

backends/cuda/runtime/shims/tests/test_aoti_torch__reinterpret_tensor.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <cuda_runtime.h>
1010
#include <executorch/backends/aoti/common_shims.h>
11+
#include <executorch/backends/aoti/utils.h>
1112
#include <executorch/backends/cuda/runtime/shims/memory.h>
1213
#include <executorch/backends/cuda/runtime/shims/tensor_attribute.h>
1314
#include <executorch/backends/cuda/runtime/utils.h>

backends/cuda/runtime/utils.h

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -71,48 +71,6 @@ enum class SupportedDevices : int32_t {
7171
CUDA = 1, // CUDA device
7272
};
7373

74-
// Utility function to convert sizes pointer to vector
75-
inline std::vector<executorch::aten::SizesType> convert_sizes_to_vector(
76-
int64_t ndim,
77-
const int64_t* sizes_ptr) {
78-
std::vector<executorch::aten::SizesType> sizes(ndim);
79-
for (int i = 0; i < ndim; i++) {
80-
sizes[i] = static_cast<executorch::aten::SizesType>(sizes_ptr[i]);
81-
}
82-
return sizes;
83-
}
84-
85-
// Utility function to convert strides pointer to vector or calculate from sizes
86-
inline std::vector<executorch::aten::StridesType> convert_strides_to_vector(
87-
int64_t ndim,
88-
const int64_t* sizes_ptr,
89-
const int64_t* strides_ptr) {
90-
std::vector<executorch::aten::StridesType> strides(ndim);
91-
92-
if (strides_ptr != nullptr) {
93-
// Use provided strides. it is ok if provided strides here is not contiguous
94-
// strides since it will be used internally in CUDA delegate.
95-
for (int64_t i = 0; i < ndim; i++) {
96-
strides[i] = static_cast<executorch::aten::StridesType>(strides_ptr[i]);
97-
}
98-
} else {
99-
// Calculate strides from sizes using ExecutorTorch's algorithm
100-
if (ndim > 0) {
101-
strides[ndim - 1] = static_cast<executorch::aten::StridesType>(
102-
1); // Last dimension has stride 1
103-
for (int64_t i = ndim - 2; i >= 0; i--) {
104-
if (sizes_ptr[i + 1] == 0) {
105-
strides[i] = strides[i + 1]; // Copy stride when size is 0
106-
} else {
107-
strides[i] = static_cast<executorch::aten::StridesType>(
108-
static_cast<int64_t>(strides[i + 1]) * sizes_ptr[i + 1]);
109-
}
110-
}
111-
}
112-
}
113-
return strides;
114-
}
115-
11674
extern "C" {
11775
using executorch::runtime::Error;
11876
// Common AOTI type aliases

0 commit comments

Comments
 (0)