Commit 77b312a
[Backend][Relax] Add Intel GNA backend for NPU support

Intel GNA (Gaussian & Neural Accelerator) hardware is present in recent Intel Core processors. This backend enables TVM users to target Intel NPU hardware via OpenVINO's GNA integration.

Features:
- Pattern-based graph partitioning for GNA-compatible operations
- JSON serialization for OpenVINO runtime integration
- Support for dense/linear, 1D convolution, and ReLU operations
- Automatic shape and dtype extraction for optimization
- Comprehensive test coverage

Supported operations:
- Dense/linear layers (relax.matmul)
- 1D convolution (relax.nn.conv1d)
- ReLU activation (relax.nn.relu)

This provides community access to Intel NPU acceleration through TVM's compilation pipeline.
1 parent 789e0b8 commit 77b312a
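For orientation, this backend follows TVM's standard Relax BYOC flow: patterns registered under the "gna." prefix drive graph partitioning, and RunCodegen dispatches the resulting composite functions to the external compiler registered in codegen.cc below. A minimal usage sketch (hypothetical driver code, not part of this commit; assumes TVM was built with USE_GNA=ON and that the helper names in tvm.relax.backend.pattern_registry match current TVM conventions):

    # Hypothetical sketch: offload GNA-compatible ops from a Relax module.
    import tvm
    from tvm import relax
    from tvm.relax.backend.pattern_registry import get_patterns_with_prefix


    def partition_for_gna(mod: tvm.IRModule) -> tvm.IRModule:
        """Group matmul/conv1d/relu calls into GNA composite functions."""
        patterns = get_patterns_with_prefix("gna")
        seq = tvm.transform.Sequential(
            [
                relax.transform.FuseOpsByPattern(patterns, annotate_codegen=True),
                relax.transform.RunCodegen(),  # invokes the "relax.ext.gna" entrypoint
            ]
        )
        return seq(mod)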

File tree

8 files changed: +795 −0 lines changed

CMakeLists.txt

Lines changed: 37 additions & 0 deletions

@@ -88,6 +88,7 @@ tvm_option(USE_BLAS "The blas library to be linked" none)
 tvm_option(USE_AMX "Enable Intel AMX" OFF)
 tvm_option(USE_MKL "MKL root path when use MKL blas" OFF)
 tvm_option(USE_DNNL "Enable DNNL codegen" OFF)
+tvm_option(USE_GNA "Enable Intel GNA codegen" OFF)
 tvm_option(USE_CUDNN "Build with cuDNN" OFF)
 tvm_option(USE_CUBLAS "Build with cuBLAS" OFF)
 tvm_option(USE_NVTX "Build with NVTX" OFF)
@@ -327,6 +328,8 @@ tvm_file_glob(GLOB DATATYPE_SRCS src/target/datatype/*.cc)
 list(APPEND COMPILER_SRCS ${DATATYPE_SRCS})
 list(APPEND COMPILER_SRCS "src/target/datatype/myfloat/myfloat.cc")
 
+list(APPEND COMPILER_SRCS "src/relax/backend/contrib/gna/codegen.cc")
+
 tvm_file_glob(GLOB RUNTIME_SRCS
   src/runtime/*.cc
   src/runtime/vm/*.cc
@@ -389,6 +392,40 @@ if (USE_CUDA AND USE_NVSHMEM)
   list(APPEND RUNTIME_SRCS ${RUNTIME_NVSHMEM_SRCS})
 endif()
 
+if(USE_GNA)
+  message(STATUS "Build with Intel GNA...")
+  tvm_file_glob(GLOB RUNTIME_GNA_SRCS src/runtime/contrib/gna/*.cc)
+  list(APPEND RUNTIME_SRCS ${RUNTIME_GNA_SRCS})
+
+  find_path(GNA_INCLUDE_DIR gna2-api.h HINTS ../gna/src/gna-api)
+  if(NOT GNA_INCLUDE_DIR)
+    message(FATAL_ERROR "Cannot find GNA headers. Expected gna2-api.h in ../gna/src/gna-api")
+  endif()
+
+  find_path(GNA_LIB_DIR NAMES gna.dll gna.so libgna.so HINTS
+    ../gna/bin/gna-lib/WIN-DEBUG/x64
+    ../gna/bin/gna-lib/WIN-RELEASE/x64
+    ../gna/bin/gna-lib/LNX-DEBUG/x64
+    ../gna/bin/gna-lib/LNX-RELEASE/x64
+    ../gna/build/src/gna-lib)
+
+  if(GNA_LIB_DIR)
+    message(STATUS "Found GNA library directory: ${GNA_LIB_DIR}")
+  else()
+    message(WARNING "GNA library not found. Build GNA first: cd ../gna && mkdir -p build && cd build && cmake .. && make")
+  endif()
+
+  include_directories(${GNA_INCLUDE_DIR})
+  if(GNA_LIB_DIR)
+    link_directories(${GNA_LIB_DIR})
+    if(WIN32)
+      list(APPEND TVM_RUNTIME_LINKER_LIBS gna.lib)
+    else()
+      list(APPEND TVM_RUNTIME_LINKER_LIBS gna)
+    endif()
+  endif()
+endif()
+
 if(USE_ROCM AND USE_RCCL)
   message(STATUS "Build with RCCL...")
   find_rccl(${USE_RCCL})

cmake/modules/LibInfo.cmake

Lines changed: 1 addition & 0 deletions

@@ -129,6 +129,7 @@ function(add_lib_info src_file)
     TVM_INFO_USE_NVSHMEM="${USE_NVSHMEM}"
     TVM_INFO_USE_NNAPI_CODEGEN="${USE_NNAPI_CODEGEN}"
     TVM_INFO_USE_NNAPI_RUNTIME="${USE_NNAPI_RUNTIME}"
+    TVM_INFO_USE_GNA="${USE_GNA}"
     TVM_INFO_BACKTRACE_ON_SEGFAULT="${BACKTRACE_ON_SEGFAULT}"
   )
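Because LibInfo.cmake exports the flag, a Python session can verify the build configuration. A small sketch (assuming the standard tvm.support.libinfo() API, which reports these TVM_INFO_* entries with the prefix stripped):

    import tvm

    # "USE_GNA" mirrors the TVM_INFO_USE_GNA entry added above.
    if tvm.support.libinfo().get("USE_GNA", "OFF") == "ON":
        print("This TVM build includes the Intel GNA backend")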

Lines changed: 19 additions & 0 deletions

@@ -0,0 +1,19 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Pattern table and codegen for GNA"""
+
+from . import gna  # noqa: F401
Lines changed: 88 additions & 0 deletions

@@ -0,0 +1,88 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Pattern table for GNA backend"""
+
+from tvm.relax.dpl.pattern import is_op, wildcard
+from tvm.relax.transform import PatternCheckContext
+
+from ...pattern_registry import register_patterns
+
+
+def _check_default(context: PatternCheckContext) -> bool:  # pylint: disable=unused-argument
+    return True
+
+
+def linear_patterns():
+    """
+    Returns a list of linear/dense patterns in GNA BYOC backend.
+    """
+
+    def _make_linear_pattern():
+        input0 = wildcard()
+        weight = wildcard()
+        out = is_op("relax.matmul")(input0, weight)
+        annotations = {"input": input0, "weight": weight, "root": out}
+        return out, annotations
+
+    def _linear_pattern(pattern_name):
+        return (pattern_name, *_make_linear_pattern(), _check_default)
+
+    return [_linear_pattern("gna.dense")]
+
+
+def conv1d_patterns():
+    """
+    Returns a list of conv1d patterns in GNA BYOC backend.
+    """
+
+    def _make_conv1d_pattern():
+        input0 = wildcard()
+        weight = wildcard()
+        out = is_op("relax.nn.conv1d")(input0, weight)
+        annotations = {"input": input0, "weight": weight, "root": out}
+        return out, annotations
+
+    def _conv1d_pattern(pattern_name):
+        return (pattern_name, *_make_conv1d_pattern(), _check_default)
+
+    return [_conv1d_pattern("gna.conv1d")]
+
+
+def activation_patterns():
+    """
+    Returns a list of activation patterns in GNA BYOC backend.
+    """
+
+    def _make_activation_pattern():
+        input0 = wildcard()
+        out = is_op("relax.nn.relu")(input0)
+        annotations = {"input": input0, "root": out}
+        return out, annotations
+
+    def _activation_pattern(pattern_name):
+        return (pattern_name, *_make_activation_pattern(), _check_default)
+
+    return [_activation_pattern("gna.relu")]
+
+
+register_patterns(
+    [
+        *linear_patterns(),
+        *conv1d_patterns(),
+        *activation_patterns(),
+    ]
+)
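To make the table above concrete, here is a sketch of a toy Relax module whose calls the registered patterns should capture (hypothetical example, not part of this commit; written against the TVMScript Relax frontend and the pattern_registry helpers):

    import tvm
    from tvm import relax
    from tvm.relax.backend.pattern_registry import get_patterns_with_prefix
    from tvm.script import relax as R


    @tvm.script.ir_module
    class ToyModule:
        @R.function
        def main(
            x: R.Tensor((1, 16), "float32"), w: R.Tensor((16, 8), "float32")
        ) -> R.Tensor((1, 8), "float32"):
            with R.dataflow():
                y = R.matmul(x, w)  # should match "gna.dense"
                z = R.nn.relu(y)  # should match "gna.relu"
                R.output(z)
            return z


    # After the pass, the module should contain composite functions
    # annotated with "gna.dense" and "gna.relu".
    mod = relax.transform.FuseOpsByPattern(get_patterns_with_prefix("gna"))(ToyModule)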
src/relax/backend/contrib/gna/codegen.cc

Lines changed: 193 additions & 0 deletions

@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/backend/contrib/gna/codegen.cc
+ * \brief Implementation of the GNA JSON serializer.
+ */
+#include <tvm/ffi/reflection/registry.h>
+#include <tvm/ir/module.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/tir/expr.h>
+
+#include <string>
+
+#include "../codegen_json/codegen_json.h"
+#include "../utils.h"
+
+namespace tvm {
+namespace relax {
+namespace contrib {
+
+using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
+using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;
+using JSONSerializer = backend::contrib::JSONSerializer;
+using backend::contrib::NodeEntries;
+
+class GNAJSONSerializer : public JSONSerializer {
+ public:
+  GNAJSONSerializer(Map<Constant, String> constant_names, Map<Var, Expr> bindings)
+      : JSONSerializer(constant_names), bindings_(bindings) {}
+
+  using JSONSerializer::VisitExpr_;
+
+  NodeEntries VisitExpr_(const CallNode* call_node) final {
+    const auto* fn_var = call_node->op.as<VarNode>();
+    ICHECK(fn_var);
+    const auto fn = Downcast<Function>(bindings_[GetRef<Var>(fn_var)]);
+    ICHECK(fn.defined()) << "Expects the callee to be a function.";
+
+    auto composite_opt = fn->GetAttr<String>(attr::kComposite);
+    ICHECK(composite_opt.has_value()) << "Only composite functions are supported.";
+
+    std::string composite_name = composite_opt.value();
+
+    NodeEntries inputs;
+    for (const auto& arg : call_node->args) {
+      auto res = VisitExpr(arg);
+      inputs.insert(inputs.end(), res.begin(), res.end());
+    }
+
+    auto node = std::make_shared<JSONGraphNode>(composite_name, /* name_ */
+                                                "kernel",       /* op_type_ */
+                                                inputs, 1 /* num_outputs_ */);
+
+    const CallNode* root_call = nullptr;
+    if (composite_name.find("gna.dense") != std::string::npos) {
+      root_call = backend::GetOpInFunction(fn, "relax.matmul");
+    } else if (composite_name.find("gna.conv1d") != std::string::npos) {
+      root_call = backend::GetOpInFunction(fn, "relax.nn.conv1d");
+    } else if (composite_name.find("gna.relu") != std::string::npos) {
+      root_call = backend::GetOpInFunction(fn, "relax.nn.relu");
+    } else {
+      LOG(FATAL) << "Unimplemented GNA pattern: " << composite_name;
+    }
+
+    SetCallNodeAttribute(node, root_call);
+    return AddNode(node, GetRef<Expr>(call_node));
+  }
+
+ private:
+  /*! \brief The bindings to look up composite functions. */
+  Map<Var, Expr> bindings_;
+
+  void SetCallNodeAttribute(std::shared_ptr<JSONGraphNode> node, const CallNode* call) {
+    // First call the base implementation to extract standard attributes
+    JSONSerializer::SetCallNodeAttribute(node, call);
+
+    // Add GNA-specific attributes based on the operation
+    if (call && call->op.as<OpNode>()) {
+      auto op = Downcast<Op>(call->op);
+      std::string op_name = op->name;
+
+      // Extract shape information from struct_info
+      if (!call->args.empty()) {
+        StructInfo input_sinfo = GetStructInfo(call->args[0]);
+        if (const auto* tensor_sinfo = input_sinfo.as<TensorStructInfoNode>()) {
+          if (tensor_sinfo->shape.defined()) {
+            std::vector<std::string> shape_strs;
+            ShapeExpr shape = Downcast<ShapeExpr>(tensor_sinfo->shape.value());
+            for (const PrimExpr& dim : shape->values) {
+              if (const auto* int_imm = dim.as<tvm::tir::IntImmNode>()) {
+                shape_strs.push_back(std::to_string(int_imm->value));
+              } else {
+                shape_strs.push_back("-1");
+              }
+            }
+            std::vector<dmlc::any> shape_attr;
+            shape_attr.emplace_back(shape_strs);
+            node->SetAttr("input_shape", shape_attr);
+          }
+
+          std::vector<std::string> dtype_strs{tensor_sinfo->dtype.code() == kDLFloat ? "float32"
+                                                                                     : "int32"};
+          std::vector<dmlc::any> dtype_attr;
+          dtype_attr.emplace_back(dtype_strs);
+          node->SetAttr("input_dtype", dtype_attr);
+        }
+      }
+
+      if (op_name == "relax.nn.conv1d") {
+        if (call->attrs.defined()) {
+          std::vector<std::string> op_attrs{"conv1d_op"};
+          std::vector<dmlc::any> op_attr;
+          op_attr.emplace_back(op_attrs);
+          node->SetAttr("gna_op_type", op_attr);
+        }
+      } else if (op_name == "relax.matmul") {
+        std::vector<std::string> op_attrs{"dense_op"};
+        std::vector<dmlc::any> op_attr;
+        op_attr.emplace_back(op_attrs);
+        node->SetAttr("gna_op_type", op_attr);
+      } else if (op_name == "relax.nn.relu") {
+        std::vector<std::string> op_attrs{"activation_op"};
+        std::vector<dmlc::any> op_attr;
+        op_attr.emplace_back(op_attrs);
+        node->SetAttr("gna_op_type", op_attr);
+      }
+    }
+  }
+};
+
+/*!
+ * \brief Create a GNA JSON runtime module.
+ * \param functions The functions to be compiled.
+ * \param unused Unused config options.
+ * \param constant_names The constant names to be used.
+ * \return Array of runtime modules.
+ */
+Array<runtime::Module> GNACompiler(Array<Function> functions, Map<String, ffi::Any> /*unused*/,
+                                   Map<Constant, String> constant_names) {
+  Array<runtime::Module> compiled_functions;
+
+  for (const auto& func : functions) {
+    GNAJSONSerializer serializer(constant_names, AnalyzeVar2Value(func));
+    serializer.serialize(func);
+    auto graph_json = serializer.GetJSON();
+    auto constant_names_used = serializer.GetConstantNames();
+
+    const auto pf = tvm::ffi::Function::GetGlobalRequired("runtime.GNAJSONRuntimeCreate");
+    auto func_name = GetExtSymbol(func);
+    compiled_functions.push_back(
+        pf(func_name, graph_json, constant_names_used).cast<runtime::Module>());
+  }
+
+  return compiled_functions;
+}
+
+// Register the external codegen entrypoint via FFI reflection (new TVM registry)
+TVM_FFI_STATIC_INIT_BLOCK({
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef().def("relax.ext.gna", GNACompiler);
+});
+
+}  // namespace contrib
+}  // namespace relax
+
+namespace target {
+
+// Register GNA target kind
+TVM_REGISTER_TARGET_KIND("gna", kDLExtDev).set_default_keys({"gna"});
+
+}  // namespace target
+
+}  // namespace tvm
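The FFI registration above is what relax.transform.RunCodegen looks up; the compiler in turn requires the runtime.GNAJSONRuntimeCreate global from the GNA runtime sources (src/runtime/contrib/gna, added to the build in CMakeLists.txt but not shown on this page). A quick sketch to confirm both ends are present in a given build (assuming the classic tvm.get_global_func helper):

    import tvm

    # Both globals must exist for end-to-end compilation: the codegen
    # entrypoint registered in codegen.cc and the JSON runtime factory.
    for name in ("relax.ext.gna", "runtime.GNAJSONRuntimeCreate"):
        func = tvm.get_global_func(name, allow_missing=True)
        print(name, "registered" if func is not None else "missing")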
