Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WebNN] Add op support validation for decomposed WebNN ops #23370

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 20 additions & 23 deletions onnxruntime/core/providers/webnn/builders/helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,7 @@
std::unordered_set<const Node*> supported_nodes;

for (const auto& node : graph_viewer.Nodes()) {
bool supported = false;
// Firstly check if platform supports the WebNN op.
if (CheckSingleOp(node.OpType(), wnn_builder, device_type)) {
supported = IsNodeSupported(node, graph_viewer, device_type, wnn_limits, logger);
}
const bool supported = IsNodeSupported(node, graph_viewer, device_type, wnn_limits, logger);
LOGS(logger, VERBOSE) << "Operator type: [" << node.OpType()
<< "] index: [" << node.Index()
<< "] name: [" << node.Name()
Expand All @@ -125,7 +121,7 @@
return supported_nodes;
}

bool AreInputDataTypesSame(const std::string& op_type,
bool AreInputDataTypesSame(const std::string_view op_type,
gsl::span<const int32_t> input_types,
const logging::Logger& logger) {
for (size_t i = 1; i < input_types.size(); i++) {
Expand All @@ -145,46 +141,47 @@
if (it == onnx_to_webnn_data_type_map.end())
return false;

std::string webnn_data_type = it->second;
const std::string_view webnn_data_type = it->second;

// Check if WebNN supports the data type.
emscripten::val is_supported = webnn_supported_data_types.call<emscripten::val>("includes",
emscripten::val(webnn_data_type));
emscripten::val is_supported =
webnn_supported_data_types.call<emscripten::val>("includes", emscripten::val(std::string(webnn_data_type)));
return is_supported.as<bool>();
}

// Check if the input or output data type of ONNX node is supported by the WebNN operator.
bool IsDataTypeSupportedByOp(const std::string& onnx_op_type,
bool IsDataTypeSupportedByOp(const std::string_view onnx_op_type,
const int32_t onnx_data_type,
const emscripten::val& wnn_limits,
const std::string& webnn_input_output_name,
const std::string& onnx_input_output_name,
const std::string_view webnn_input_output_name,
const std::string_view onnx_input_output_name,
const logging::Logger& logger) {
std::string webnn_op_type;
if (!GetWebNNOpType(onnx_op_type, webnn_op_type))
return false;
const std::string_view webnn_op_type = GetWebNNOpType(onnx_op_type);

return IsDataTypeSupportedByWebNNOp(onnx_op_type, webnn_op_type, onnx_data_type, wnn_limits,
return !webnn_op_type.empty() &&
IsDataTypeSupportedByWebNNOp(onnx_op_type, webnn_op_type, onnx_data_type, wnn_limits,
webnn_input_output_name, onnx_input_output_name, logger);
}

bool IsDataTypeSupportedByWebNNOp(const std::string& onnx_op_type,
const std::string& webnn_op_type,
bool IsDataTypeSupportedByWebNNOp(const std::string_view onnx_op_type,
const std::string_view webnn_op_type,
const int32_t onnx_data_type,
const emscripten::val& wnn_limits,
const std::string& webnn_input_output_name,
const std::string& onnx_input_output_name,
const std::string_view webnn_input_output_name,
const std::string_view onnx_input_output_name,
const logging::Logger& logger) {
if (wnn_limits[webnn_op_type].isUndefined()) {
if (wnn_limits[std::string(webnn_op_type)].isUndefined()) {
LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] WebNN op [" << webnn_op_type << "] is not supported for now";
return false;
}
if (wnn_limits[webnn_op_type][webnn_input_output_name].isUndefined()) {

if (wnn_limits[std::string(webnn_op_type)][std::string(webnn_input_output_name)].isUndefined()) {
LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] WebNN op [" << webnn_op_type << "] doesn't have parameter ["
<< webnn_input_output_name << "]";
return false;
}
if (!IsSupportedDataType(onnx_data_type, wnn_limits[webnn_op_type][webnn_input_output_name]["dataTypes"])) {
if (!IsSupportedDataType(
onnx_data_type, wnn_limits[std::string(webnn_op_type)][std::string(webnn_input_output_name)]["dataTypes"])) {

Check warning on line 184 in onnxruntime/core/providers/webnn/builders/helper.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webnn/builders/helper.cc:184: Add #include <string> for string [build/include_what_you_use] [4]
LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] " << onnx_input_output_name << "'s data type: ["
<< onnx_data_type << "] is not supported by WebNN op [" << webnn_op_type << "] for now";
return false;
Expand Down
84 changes: 52 additions & 32 deletions onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,16 @@
const WebnnDeviceType device_type,
const emscripten::val& wnn_limits,
const logging::Logger& logger);
// TODO(@Honry): Some ONNX ops are supported by decomposed WebNN ops,
// we need to check the support of the decomposed ops.
static const InlinedHashMap<std::string, std::string> op_map = {

// Some ONNX ops are supported by decomposed WebNN ops: the ONNX op maps to the
// list of WebNN ops its emulation is built from, and support validation must
// check every op in the list.
// Declared `inline` (C++17) so all translation units including this header
// share a single instance instead of each getting a private internal-linkage copy.
inline const std::map<std::string_view, std::vector<std::string_view>> decomposed_op_map = {
    {"LRN", {"add", "averagePool2d", "div", "mul", "pad", "pow", "transpose"}},
    {"RotaryEmbedding", {"add", "concat", "gather", "mul", "reshape", "split"}},
    {"SimplifiedLayerNormalization", {"add", "div", "mul", "pow", "reduceMean", "sqrt"}},
    {"SkipSimplifiedLayerNormalization", {"add", "div", "mul", "pow", "reduceMean", "sqrt"}},
};
// ONNX op type to WebNN op type mapping.
const std::map<std::string_view, std::string_view> op_map = {
{"Abs", "abs"},
{"Add", "add"},
{"And", "logicalAnd"},
Expand Down Expand Up @@ -247,7 +254,6 @@
{"Log", "log"},
{"LpPool", "l2Pool2d"},
{"LSTM", "lstm"},
{"LRN", "averagePool2d"},
{"MatMul", "matmul"},
{"MatMulInteger", "matmulInteger"},
{"Max", "max"},
Expand Down Expand Up @@ -275,17 +281,14 @@
{"Relu", "relu"},
{"Reshape", "reshape"},
{"Resize", "resample2d"},
{"RotaryEmbedding", "gather"},
{"ScatterElements", "scatterElements"},
{"ScatterND", "scatterND"},
{"Shape", "slice"},
{"Sigmoid", "sigmoid"},
{"Sign", "sign"},
{"SimplifiedLayerNormalization", "layerNormalization"},
{"Softplus", "softplus"},
{"Softsign", "softsign"},
{"Sin", "sin"},
{"SkipSimplifiedLayerNormalization", "layerNormalization"},
{"Slice", "slice"},
{"Softmax", "softmax"},
{"Split", "split"},
Expand All @@ -302,29 +305,46 @@
{"Xor", "logicalXor"},
};

inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn_builder,
const WebnnDeviceType device_type) {
auto op_map_entry = op_map.find(op_type);
// Returns false if the op_type is not listed in the op_map or
// if the WebNN op has not been implemented in MLGraphBuilder in current browser.
if (op_map_entry == op_map.end() || !wnn_builder[op_map_entry->second].as<bool>()) {
return false;
}
// WebNN op name to its first input name mapping; only names that differ from
// the default "input" are recorded. Consumed when validating supported input
// data types against OpSupportLimits.
// Declared `inline` (C++17) so all translation units including this header
// share a single instance instead of each getting a private internal-linkage copy.
inline const std::map<std::string_view, std::string_view> webnn_op_first_input_name_map = {
    {"add", "a"},
    {"concat", "inputs"},
    {"div", "a"},
    {"equal", "a"},
    {"gemm", "a"},
    {"greater", "a"},
    {"greaterOrEqual", "a"},
    {"lesser", "a"},
    {"lesserOrEqual", "a"},
    {"logicalAnd", "a"},
    {"logicalNot", "a"},
    {"logicalOr", "a"},
    {"logicalXor", "a"},
    {"matmul", "a"},
    {"max", "a"},
    {"min", "a"},
    {"mul", "a"},
    {"pow", "a"},
    {"sub", "a"},
    {"where", "condition"},
};

return true;
// Look up the name of a WebNN op's first input, for use when validating
// supported input data types. Most WebNN ops name their first input "input";
// the exceptions (e.g. "a", "inputs", "condition") are listed in
// webnn_op_first_input_name_map, which is consulted here.
inline std::string_view GetWebNNOpFirstInputName(const std::string_view webnn_op_type) {
  const auto entry = webnn_op_first_input_name_map.find(webnn_op_type);
  if (entry == webnn_op_first_input_name_map.end()) {
    return "input";
  }
  return entry->second;
}

inline bool GetWebNNOpType(const std::string& op_type, std::string& webnn_op_type) {
inline std::string_view GetWebNNOpType(const std::string_view op_type) {
auto it = op_map.find(op_type);
// Returns false if the op_type is not listed in the op_map.
if (it == op_map.end()) {
return false;
}
webnn_op_type = it->second;
return true;
// Return an empty string if the op_type is not listed in the op_map.
return (it != op_map.end()) ? it->second : "";
}

static const InlinedHashMap<ONNX_NAMESPACE::TensorProto_DataType, std::string> onnx_to_webnn_data_type_map = {
const std::map<ONNX_NAMESPACE::TensorProto_DataType, std::string_view> onnx_to_webnn_data_type_map = {

Check warning on line 347 in onnxruntime/core/providers/webnn/builders/helper.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <map> for map<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webnn/builders/helper.h:347: Add #include <map> for map<> [build/include_what_you_use] [4]
{ONNX_NAMESPACE::TensorProto_DataType_INT4, "int4"},
{ONNX_NAMESPACE::TensorProto_DataType_UINT4, "uint4"},
{ONNX_NAMESPACE::TensorProto_DataType_BOOL, "uint8"},
Expand All @@ -338,22 +358,22 @@
{ONNX_NAMESPACE::TensorProto_DataType_UINT64, "uint64"},
};

bool AreInputDataTypesSame(const std::string& op_type,
bool AreInputDataTypesSame(const std::string_view op_type,
gsl::span<const int32_t> input_types,
const logging::Logger& logger);
bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& webnn_supported_data_types);
bool IsDataTypeSupportedByOp(const std::string& onnx_op_type,
bool IsDataTypeSupportedByOp(const std::string_view onnx_op_type,
const int32_t onnx_data_type,
const emscripten::val& wnn_limits,
const std::string& webnn_input_output_name,
const std::string& onnx_input_output_name,
const std::string_view webnn_input_output_name,
const std::string_view onnx_input_output_name,
const logging::Logger& logger);
bool IsDataTypeSupportedByWebNNOp(const std::string& onnx_op_type,
const std::string& webnn_op_type,
bool IsDataTypeSupportedByWebNNOp(const std::string_view onnx_op_type,
const std::string_view webnn_op_type,
const int32_t onnx_data_type,
const emscripten::val& wnn_limits,
const std::string& webnn_input_output_name,
const std::string& onnx_input_output_name,
const std::string_view webnn_input_output_name,
const std::string_view onnx_input_output_name,
const logging::Logger& logger);

bool GetBidirectionalBroadcastShape(std::vector<int64_t>& shape_a,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,17 @@ bool BaseOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& initializ
const logging::Logger& logger) const {
// We only check the type of input 0 by default, specific op builder can override this.
const auto& input = *node.InputDefs()[0];
const auto& op_type = node.OpType();
const std::string_view op_type = node.OpType();
int32_t input_type;
if (!GetType(input, input_type, logger))
return false;
const std::string_view webnn_op_type = GetWebNNOpType(op_type);
if (webnn_op_type.empty())
return false;

return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "Input", logger);
const std::string_view webnn_input_name = GetWebNNOpFirstInputName(webnn_op_type);
return IsDataTypeSupportedByWebNNOp(op_type, webnn_op_type, input_type, wnn_limits,
webnn_input_name, "input", logger);
}

bool BaseOpBuilder::HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits,
Expand All @@ -83,7 +88,7 @@ bool BaseOpBuilder::HasSupportedOutputsImpl(const Node& node,
const logging::Logger& logger) const {
// We only check the type of output 0 by default, specific op builder can override this.
const auto& output = *node.OutputDefs()[0];
const auto& op_type = node.OpType();
const std::string_view op_type = node.OpType();
int32_t output_type;
if (!GetType(output, output_type, logger))
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
bool BinaryOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
const std::string_view op_type = node.OpType();
int32_t input0_type;
int32_t input1_type;

Expand Down
24 changes: 0 additions & 24 deletions onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,6 @@ class CastOpBuilder : public BaseOpBuilder {
private:
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;

// Operator support related.
private:
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};

// Add operator related.
Expand Down Expand Up @@ -85,25 +80,6 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
return Status::OK();
}

// Operator support related.
// Validates Cast support: both the element type of input 0 and the target type
// requested by the "to" attribute must be accepted by the WebNN op's limits.
bool CastOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
int32_t input_type;

// Resolve the element type of input 0; if it cannot be determined, the node is unsupported.
if (!GetType(*input_defs[0], input_type, logger))
return false;

// The source data type must be supported as the WebNN op's "input" parameter.
if (!IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "input", logger))
return false;

NodeAttrHelper helper(node);
// Check cast to type.
// Falls back to UNDEFINED when the "to" attribute is absent, which will fail validation below.
const auto to_type = helper.Get("to", ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED);
// The target data type must be supported as the WebNN op's "output".
return IsDataTypeSupportedByOp(op_type, to_type, wnn_limits, "output", "to", logger);
}

void CreateCastOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(std::make_unique<CastOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
bool ConcatOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
const std::string_view op_type = node.OpType();
int32_t input0_type;

if (!GetType(*input_defs[0], input0_type, logger))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
bool ConvOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
const std::string_view op_type = node.OpType();
int32_t input0_type; // input data type
int32_t input1_type; // weight data type
int32_t input2_type; // bias or x_zero_point data type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -739,7 +739,7 @@ bool EinsumOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* init
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();

const auto& op_type = node.OpType();
const std::string_view op_type = node.OpType();
int32_t input0_type;
int32_t input1_type;
bool has_input1 = TensorExists(input_defs, 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ bool GatherElementsOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet&
const logging::Logger& logger) const {
const auto& data = *node.InputDefs()[0];
const auto& indices = *node.InputDefs()[1];
const auto& op_type = node.OpType();
const std::string_view op_type = node.OpType();

int32_t data_type;
int32_t indices_type;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ bool GatherNDOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* in
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& data = *node.InputDefs()[0];
const auto& indices = *node.InputDefs()[1];
const auto& op_type = node.OpType();
const std::string_view op_type = node.OpType();

int32_t data_type;
int32_t indices_type;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ bool GatherOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* init
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input = *node.InputDefs()[0];
const auto& indices = *node.InputDefs()[1];
const auto& op_type = node.OpType();
const std::string_view op_type = node.OpType();
int32_t input_type;
int32_t indices_type;
if (!GetType(input, input_type, logger) ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializer
bool GemmOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
const std::string_view op_type = node.OpType();
int32_t input0_type; // A data type
int32_t input1_type; // B data type
int32_t input2_type; // C or a_zero_point data type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ bool GruOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, c
bool GruOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
const std::string_view op_type = node.OpType();
int32_t input_X_type = 0; // input data type
int32_t input_W_type = 0; // weight data type
int32_t input_R_type = 0; // recurrent weight data type
Expand Down Expand Up @@ -226,7 +226,7 @@ bool GruOpBuilder::HasSupportedOutputsImpl(const Node& node,
const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto& output_defs = node.OutputDefs();
const auto& op_type = node.OpType();
const std::string_view op_type = node.OpType();
int32_t Y_type = 0;
int32_t Y_h_type = 0;
bool has_Y = TensorExists(output_defs, 0);
Expand Down
Loading
Loading