Commit 2c036cc

[Backend][Relax] Add Intel GNA backend with CPU emulation for CI
This commit introduces the Intel GNA (Gaussian Neural Accelerator) backend for TVM's Relax IR, with a clean separation between hardware and emulation runtimes to enable CI testing without GNA hardware.

Key components:
- GNA codegen for Relax IR (graph partitioning and code generation)
- Hardware runtime (gna_json_runtime.cc) for systems with the GNA SDK
- CPU emulation runtime (gna_json_runtime_emulation.cc) for CI/testing
- Conditional CMake build based on GNA SDK availability
- Pattern registry for dense, conv1d, and relu operations
- Comprehensive test suite

Architecture decisions:
- Clean separation: hardware and emulation live in separate files (no mocking)
- CI-friendly: the emulation runtime has no GNA SDK dependencies
- Follows OpenVINO's Software Emulation Mode pattern
- Same API surface for both runtime implementations

The emulation runtime provides simplified reference implementations sufficient for testing graph partitioning and codegen correctness. For production CPU inference, use TVM's standard CPU backend.

This backend serves as a stepping stone toward Intel NPU support and provides a minimal example for Relax backend development.
1 parent 789e0b8 commit 2c036cc
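
To make the flow concrete, here is a minimal sketch of how the partitioning and codegen described above would be driven from Python. The toy module, shapes, and pass options are illustrative assumptions; the pattern names ("gna.dense", "gna.relu") come from the pattern table added in this commit, while get_patterns_with_prefix, FuseOpsByPattern, and RunCodegen are TVM's standard BYOC entry points.

# Minimal sketch, assuming TVM was configured with USE_GNA=ON.
# The module below is a made-up example, not part of this commit.
import tvm
from tvm import relax
from tvm.relax.backend.pattern_registry import get_patterns_with_prefix
from tvm.script import ir as I, relax as R


@I.ir_module
class Model:
    @R.function
    def main(
        x: R.Tensor((1, 16), "float32"), w: R.Tensor((16, 16), "float32")
    ) -> R.Tensor((1, 16), "float32"):
        with R.dataflow():
            y = R.matmul(x, w)
            z = R.nn.relu(y)
            R.output(z)
        return z


# Offload matching subgraphs to the external codegen registered as "relax.ext.gna".
patterns = get_patterns_with_prefix("gna")
seq = tvm.ir.transform.Sequential(
    [
        relax.transform.FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=True),
        relax.transform.RunCodegen(),
    ]
)
mod = seq(Model)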

File tree

9 files changed (+1105 −0 lines)

CMakeLists.txt

Lines changed: 50 additions & 0 deletions
@@ -88,6 +88,7 @@ tvm_option(USE_BLAS "The blas library to be linked" none)
 tvm_option(USE_AMX "Enable Intel AMX" OFF)
 tvm_option(USE_MKL "MKL root path when use MKL blas" OFF)
 tvm_option(USE_DNNL "Enable DNNL codegen" OFF)
+tvm_option(USE_GNA "Enable Intel GNA codegen" OFF)
 tvm_option(USE_CUDNN "Build with cuDNN" OFF)
 tvm_option(USE_CUBLAS "Build with cuBLAS" OFF)
 tvm_option(USE_NVTX "Build with NVTX" OFF)
@@ -327,6 +328,8 @@ tvm_file_glob(GLOB DATATYPE_SRCS src/target/datatype/*.cc)
 list(APPEND COMPILER_SRCS ${DATATYPE_SRCS})
 list(APPEND COMPILER_SRCS "src/target/datatype/myfloat/myfloat.cc")

+list(APPEND COMPILER_SRCS "src/relax/backend/contrib/gna/codegen.cc")
+
 tvm_file_glob(GLOB RUNTIME_SRCS
   src/runtime/*.cc
   src/runtime/vm/*.cc
@@ -389,6 +392,53 @@ if (USE_CUDA AND USE_NVSHMEM)
   list(APPEND RUNTIME_SRCS ${RUNTIME_NVSHMEM_SRCS})
 endif()

+if(USE_GNA)
+  message(STATUS "Build with Intel GNA...")
+
+  # Try to find GNA SDK headers
+  find_path(GNA_INCLUDE_DIR gna2-api.h HINTS ../gna/src/gna-api)
+
+  if(GNA_INCLUDE_DIR)
+    # Full hardware support with SDK
+    message(STATUS "Found GNA headers at ${GNA_INCLUDE_DIR} - building with hardware support")
+    list(APPEND RUNTIME_SRCS src/runtime/contrib/gna/gna_json_runtime.cc)
+  else()
+    # CPU emulation only (for CI and development without SDK)
+    message(STATUS "GNA headers not found - building with CPU emulation only (suitable for CI)")
+    list(APPEND RUNTIME_SRCS src/runtime/contrib/gna/gna_json_runtime_emulation.cc)
+    set(GNA_EMULATION_ONLY ON)
+  endif()
+
+  find_path(GNA_LIB_DIR NAMES gna.dll gna.so libgna.so HINTS
+    ../gna/bin/gna-lib/WIN-DEBUG/x64
+    ../gna/bin/gna-lib/WIN-RELEASE/x64
+    ../gna/bin/gna-lib/LNX-DEBUG/x64
+    ../gna/bin/gna-lib/LNX-RELEASE/x64
+    ../gna/build/src/gna-lib)
+
+  if(GNA_LIB_DIR)
+    message(STATUS "Found GNA library directory: ${GNA_LIB_DIR}")
+  else()
+    message(WARNING "GNA library not found. Build GNA first: cd ../gna && mkdir -p build && cd build && cmake .. && make")
+  endif()
+
+  if(NOT GNA_EMULATION_ONLY)
+    include_directories(${GNA_INCLUDE_DIR})
+    if(GNA_LIB_DIR)
+      link_directories(${GNA_LIB_DIR})
+      if(WIN32)
+        list(APPEND TVM_RUNTIME_LINKER_LIBS gna.lib)
+      else()
+        list(APPEND TVM_RUNTIME_LINKER_LIBS gna)
+      endif()
+    endif()
+  endif()
+
+  # Always include codegen (doesn't require GNA SDK)
+  tvm_file_glob(GLOB GNA_CODEGEN_SRCS src/relax/backend/contrib/gna/*.cc)
+  list(APPEND COMPILER_SRCS ${GNA_CODEGEN_SRCS})
+endif()
+
 if(USE_ROCM AND USE_RCCL)
   message(STATUS "Build with RCCL...")
   find_rccl(${USE_RCCL})

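One practical consequence of the conditional sources above: per the commit's "same API surface" decision, the same global function is registered whether the hardware or the emulation runtime was compiled in. A small sketch for probing a given build (the lookup is standard TVM API; the function name comes from codegen.cc below):

# Minimal sketch: probe for the GNA JSON runtime creator. Its presence
# confirms USE_GNA=ON, but not whether the hardware or emulation file was
# built, since both register "runtime.GNAJSONRuntimeCreate".
import tvm

creator = tvm.get_global_func("runtime.GNAJSONRuntimeCreate", allow_missing=True)
print("GNA runtime compiled in:", creator is not None)
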
cmake/modules/LibInfo.cmake

Lines changed: 1 addition & 0 deletions
@@ -129,6 +129,7 @@ function(add_lib_info src_file)
   TVM_INFO_USE_NVSHMEM="${USE_NVSHMEM}"
   TVM_INFO_USE_NNAPI_CODEGEN="${USE_NNAPI_CODEGEN}"
   TVM_INFO_USE_NNAPI_RUNTIME="${USE_NNAPI_RUNTIME}"
+  TVM_INFO_USE_GNA="${USE_GNA}"
   TVM_INFO_BACKTRACE_ON_SEGFAULT="${BACKTRACE_ON_SEGFAULT}"
 )

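With TVM_INFO_USE_GNA exported here, the build flag also becomes visible from Python; a small sketch using tvm.support.libinfo, TVM's standard API for the TVM_INFO_* values baked in at build time:

# Minimal sketch: report how this TVM build was configured for GNA.
import tvm.support

print("USE_GNA =", tvm.support.libinfo().get("USE_GNA", "not set"))
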
Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Pattern table and codegen for GNA"""
+
+from . import gna  # noqa: F401
Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Pattern table for GNA backend"""
+
+from tvm.relax.dpl.pattern import is_op, wildcard
+from tvm.relax.transform import PatternCheckContext
+
+from ...pattern_registry import register_patterns
+
+
+def _check_default(context: PatternCheckContext) -> bool:  # pylint: disable=unused-argument
+    return True
+
+
+def linear_patterns():
+    """
+    Returns a list of linear/dense patterns in GNA BYOC backend.
+    """
+
+    def _make_linear_pattern():
+        input0 = wildcard()
+        weight = wildcard()
+        out = is_op("relax.matmul")(input0, weight)
+        annotations = {"input": input0, "weight": weight, "root": out}
+        return out, annotations
+
+    def _linear_pattern(pattern_name):
+        return (pattern_name, *_make_linear_pattern(), _check_default)
+
+    return [_linear_pattern("gna.dense")]
+
+
+def conv1d_patterns():
+    """
+    Returns a list of conv1d patterns in GNA BYOC backend.
+    """
+
+    def _make_conv1d_pattern():
+        input0 = wildcard()
+        weight = wildcard()
+        out = is_op("relax.nn.conv1d")(input0, weight)
+        annotations = {"input": input0, "weight": weight, "root": out}
+        return out, annotations
+
+    def _conv1d_pattern(pattern_name):
+        return (pattern_name, *_make_conv1d_pattern(), _check_default)
+
+    return [_conv1d_pattern("gna.conv1d")]
+
+
+def activation_patterns():
+    """
+    Returns a list of activation patterns in GNA BYOC backend.
+    """
+
+    def _make_activation_pattern():
+        input0 = wildcard()
+        out = is_op("relax.nn.relu")(input0)
+        annotations = {"input": input0, "root": out}
+        return out, annotations
+
+    def _activation_pattern(pattern_name):
+        return (pattern_name, *_make_activation_pattern(), _check_default)
+
+    return [_activation_pattern("gna.relu")]
+
+
+register_patterns(
+    [
+        *linear_patterns(),
+        *conv1d_patterns(),
+        *activation_patterns(),
+    ]
+)
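
Once registered, these entries are discoverable through the shared BYOC pattern registry. A short sketch (get_patterns_with_prefix is the existing registry helper, not something this commit adds):

# Minimal sketch: look up the GNA patterns registered above by prefix.
from tvm.relax.backend.pattern_registry import get_patterns_with_prefix

for entry in get_patterns_with_prefix("gna"):
    print(entry.name)  # gna.dense, gna.conv1d, gna.relu
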
Lines changed: 193 additions & 0 deletions
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/backend/contrib/gna/codegen.cc
+ * \brief Implementation of the GNA JSON serializer.
+ */
+#include <tvm/ffi/reflection/registry.h>
+#include <tvm/ir/module.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/tir/expr.h>
+
+#include <string>
+
+#include "../codegen_json/codegen_json.h"
+#include "../utils.h"
+
+namespace tvm {
+namespace relax {
+namespace contrib {
+
+using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
+using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;
+using JSONSerializer = backend::contrib::JSONSerializer;
+using backend::contrib::NodeEntries;
+
+class GNAJSONSerializer : public JSONSerializer {
+ public:
+  GNAJSONSerializer(Map<Constant, String> constant_names, Map<Var, Expr> bindings)
+      : JSONSerializer(constant_names), bindings_(bindings) {}
+
+  using JSONSerializer::VisitExpr_;
+
+  NodeEntries VisitExpr_(const CallNode* call_node) final {
+    const auto* fn_var = call_node->op.as<VarNode>();
+    ICHECK(fn_var);
+    const auto fn = Downcast<Function>(bindings_[GetRef<Var>(fn_var)]);
+    ICHECK(fn.defined()) << "Expects the callee to be a function.";
+
+    auto composite_opt = fn->GetAttr<String>(attr::kComposite);
+    ICHECK(composite_opt.has_value()) << "Only composite functions are supported.";
+
+    std::string composite_name = composite_opt.value();
+
+    NodeEntries inputs;
+    for (const auto& arg : call_node->args) {
+      auto res = VisitExpr(arg);
+      inputs.insert(inputs.end(), res.begin(), res.end());
+    }
+
+    auto node = std::make_shared<JSONGraphNode>(composite_name, /* name_ */
+                                                "kernel",       /* op_type_ */
+                                                inputs, 1 /* num_outputs_ */);
+
+    const CallNode* root_call = nullptr;
+    if (composite_name.find("gna.dense") != std::string::npos) {
+      root_call = backend::GetOpInFunction(fn, "relax.matmul");
+    } else if (composite_name.find("gna.conv1d") != std::string::npos) {
+      root_call = backend::GetOpInFunction(fn, "relax.nn.conv1d");
+    } else if (composite_name.find("gna.relu") != std::string::npos) {
+      root_call = backend::GetOpInFunction(fn, "relax.nn.relu");
+    } else {
+      LOG(FATAL) << "Unimplemented GNA pattern: " << composite_name;
+    }
+
+    SetCallNodeAttribute(node, root_call);
+    return AddNode(node, GetRef<Expr>(call_node));
+  }
+
+ private:
+  /*! \brief The bindings to look up composite functions. */
+  Map<Var, Expr> bindings_;
+
+  void SetCallNodeAttribute(std::shared_ptr<JSONGraphNode> node, const CallNode* call) {
+    // First call the base implementation to extract standard attributes
+    JSONSerializer::SetCallNodeAttribute(node, call);
+
+    // Add GNA-specific attributes based on the operation
+    if (call && call->op.as<OpNode>()) {
+      auto op = Downcast<Op>(call->op);
+      std::string op_name = op->name;
+
+      // Extract shape information from struct_info
+      if (!call->args.empty()) {
+        StructInfo input_sinfo = GetStructInfo(call->args[0]);
+        if (const auto* tensor_sinfo = input_sinfo.as<TensorStructInfoNode>()) {
+          if (tensor_sinfo->shape.defined()) {
+            std::vector<std::string> shape_strs;
+            ShapeExpr shape = Downcast<ShapeExpr>(tensor_sinfo->shape.value());
+            for (const PrimExpr& dim : shape->values) {
+              if (const auto* int_imm = dim.as<tvm::tir::IntImmNode>()) {
+                shape_strs.push_back(std::to_string(int_imm->value));
+              } else {
+                shape_strs.push_back("-1");
+              }
+            }
+            std::vector<dmlc::any> shape_attr;
+            shape_attr.emplace_back(shape_strs);
+            node->SetAttr("input_shape", shape_attr);
+          }
+
+          std::vector<std::string> dtype_strs{tensor_sinfo->dtype.code() == kDLFloat ? "float32"
+                                                                                     : "int32"};
+          std::vector<dmlc::any> dtype_attr;
+          dtype_attr.emplace_back(dtype_strs);
+          node->SetAttr("input_dtype", dtype_attr);
+        }
+      }
+
+      if (op_name == "relax.nn.conv1d") {
+        if (call->attrs.defined()) {
+          std::vector<std::string> op_attrs{"conv1d_op"};
+          std::vector<dmlc::any> op_attr;
+          op_attr.emplace_back(op_attrs);
+          node->SetAttr("gna_op_type", op_attr);
+        }
+      } else if (op_name == "relax.matmul") {
+        std::vector<std::string> op_attrs{"dense_op"};
+        std::vector<dmlc::any> op_attr;
+        op_attr.emplace_back(op_attrs);
+        node->SetAttr("gna_op_type", op_attr);
+      } else if (op_name == "relax.nn.relu") {
+        std::vector<std::string> op_attrs{"activation_op"};
+        std::vector<dmlc::any> op_attr;
+        op_attr.emplace_back(op_attrs);
+        node->SetAttr("gna_op_type", op_attr);
+      }
+    }
+  }
+};
+
+/*!
+ * \brief Create a GNA JSON runtime module.
+ * \param functions The functions to be compiled.
+ * \param unused Unused config options.
+ * \param constant_names The constant names to be used.
+ * \return Array of runtime modules.
+ */
+Array<runtime::Module> GNACompiler(Array<Function> functions, Map<String, ffi::Any> /*unused*/,
+                                   Map<Constant, String> constant_names) {
+  Array<runtime::Module> compiled_functions;
+
+  for (const auto& func : functions) {
+    GNAJSONSerializer serializer(constant_names, AnalyzeVar2Value(func));
+    serializer.serialize(func);
+    auto graph_json = serializer.GetJSON();
+    auto constant_names_used = serializer.GetConstantNames();
+
+    const auto pf = tvm::ffi::Function::GetGlobalRequired("runtime.GNAJSONRuntimeCreate");
+    auto func_name = GetExtSymbol(func);
+    compiled_functions.push_back(
+        pf(func_name, graph_json, constant_names_used).cast<runtime::Module>());
+  }
+
+  return compiled_functions;
+}
+
+// Register the external codegen entrypoint via FFI reflection (new TVM registry)
+TVM_FFI_STATIC_INIT_BLOCK({
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef().def("relax.ext.gna", GNACompiler);
+});
+
+}  // namespace contrib
+}  // namespace relax
+
+namespace target {
+
+// Register GNA target kind
+TVM_REGISTER_TARGET_KIND("gna", kDLExtDev).set_default_keys({"gna"});
+
+}  // namespace target
+
+}  // namespace tvm

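To close the loop, here is a hedged sketch of compiling and running the partitioned module from the sketch near the top of this page. tvm.compile and relax.VirtualMachine are TVM's standard APIs; the host target, shapes, and tolerance are assumptions, and the NumPy comparison reflects the commit's note that the emulation runtime is a simplified reference implementation.

# Minimal sketch (assumes "mod" from the partitioning sketch above and a
# TVM build with USE_GNA=ON; in CI this exercises the emulation runtime).
import numpy as np
import tvm
from tvm import relax

ex = tvm.compile(mod, target="llvm")  # host target; offloaded subgraphs go to the GNA JSON runtime
vm = relax.VirtualMachine(ex, tvm.cpu())

x = np.random.rand(1, 16).astype("float32")
w = np.random.rand(16, 16).astype("float32")
out = vm["main"](tvm.nd.array(x), tvm.nd.array(w))

# Reference for the offloaded matmul -> relu chain.
np.testing.assert_allclose(out.numpy(), np.maximum(x @ w, 0.0), rtol=1e-5)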