CMakeLists.txt (49 additions, 0 deletions)
@@ -88,6 +88,8 @@ tvm_option(USE_BLAS "The blas library to be linked" none)
tvm_option(USE_AMX "Enable Intel AMX" OFF)
tvm_option(USE_MKL "MKL root path when use MKL blas" OFF)
tvm_option(USE_DNNL "Enable DNNL codegen" OFF)
tvm_option(USE_GNA_CODEGEN "Build with Intel GNA Codegen support" OFF)
tvm_option(USE_GNA_RUNTIME "Build with Intel GNA runtime" OFF)
tvm_option(USE_CUDNN "Build with cuDNN" OFF)
tvm_option(USE_CUBLAS "Build with cuBLAS" OFF)
tvm_option(USE_NVTX "Build with NVTX" OFF)
@@ -327,6 +329,10 @@ tvm_file_glob(GLOB DATATYPE_SRCS src/target/datatype/*.cc)
list(APPEND COMPILER_SRCS ${DATATYPE_SRCS})
list(APPEND COMPILER_SRCS "src/target/datatype/myfloat/myfloat.cc")

if(USE_GNA_CODEGEN)
  list(APPEND COMPILER_SRCS "src/relax/backend/contrib/gna/codegen.cc")
endif()

tvm_file_glob(GLOB RUNTIME_SRCS
  src/runtime/*.cc
  src/runtime/vm/*.cc
@@ -389,6 +395,49 @@ if (USE_CUDA AND USE_NVSHMEM)
  list(APPEND RUNTIME_SRCS ${RUNTIME_NVSHMEM_SRCS})
endif()

if(USE_GNA_RUNTIME)
message(STATUS "Build with Intel GNA runtime...")

# Try to find GNA SDK headers
find_path(GNA_INCLUDE_DIR gna2-api.h HINTS ../gna/src/gna-api)

if(GNA_INCLUDE_DIR)
# Full hardware support with SDK
message(STATUS "Found GNA headers at ${GNA_INCLUDE_DIR} - building with hardware support")
list(APPEND RUNTIME_SRCS src/runtime/contrib/gna/gna_json_runtime.cc)
else()
# CPU emulation only (for CI and development without SDK)
message(STATUS "GNA headers not found - building with CPU emulation only (suitable for CI)")
list(APPEND RUNTIME_SRCS src/runtime/contrib/gna/gna_json_runtime_emulation.cc)
set(GNA_EMULATION_ONLY ON)
endif()

find_path(GNA_LIB_DIR NAMES gna.dll gna.so libgna.so HINTS
../gna/bin/gna-lib/WIN-DEBUG/x64
../gna/bin/gna-lib/WIN-RELEASE/x64
../gna/bin/gna-lib/LNX-DEBUG/x64
../gna/bin/gna-lib/LNX-RELEASE/x64
../gna/build/src/gna-lib)

if(GNA_LIB_DIR)
message(STATUS "Found GNA library directory: ${GNA_LIB_DIR}")
else()
message(WARNING "GNA library not found. Build GNA first: cd ../gna && mkdir -p build && cd build && cmake .. && make")
endif()

if(NOT GNA_EMULATION_ONLY)
include_directories(${GNA_INCLUDE_DIR})
if(GNA_LIB_DIR)
link_directories(${GNA_LIB_DIR})
if(WIN32)
list(APPEND TVM_RUNTIME_LINKER_LIBS gna.lib)
else()
list(APPEND TVM_RUNTIME_LINKER_LIBS gna)
endif()
endif()
endif()
endif()

if(USE_ROCM AND USE_RCCL)
message(STATUS "Build with RCCL...")
find_rccl(${USE_RCCL})
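Note: with `USE_GNA_RUNTIME` enabled (e.g. `set(USE_GNA_RUNTIME ON)` in `config.cmake`), either runtime variant above registers the `runtime.GNAJSONRuntimeCreate` global that the codegen later looks up. A minimal Python probe, assuming a build from this branch is on the Python path:

```python
import tvm

# Probe for the GNA JSON runtime (hardware or CPU-emulation build).
# allow_missing=True returns None instead of raising when absent.
f = tvm.get_global_func("runtime.GNAJSONRuntimeCreate", allow_missing=True)
print("GNA runtime compiled in:", f is not None)
```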
cmake/modules/LibInfo.cmake (2 additions, 0 deletions)
@@ -129,6 +129,8 @@ function(add_lib_info src_file)
    TVM_INFO_USE_NVSHMEM="${USE_NVSHMEM}"
    TVM_INFO_USE_NNAPI_CODEGEN="${USE_NNAPI_CODEGEN}"
    TVM_INFO_USE_NNAPI_RUNTIME="${USE_NNAPI_RUNTIME}"
    TVM_INFO_USE_GNA_CODEGEN="${USE_GNA_CODEGEN}"
    TVM_INFO_USE_GNA_RUNTIME="${USE_GNA_RUNTIME}"
    TVM_INFO_BACKTRACE_ON_SEGFAULT="${BACKTRACE_ON_SEGFAULT}"
  )

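Since these `TVM_INFO_` entries are baked into the library, both flags can be read back from Python; a small sketch (libinfo keys drop the `TVM_INFO_` prefix):

```python
import tvm.support

# Print the GNA build flags recorded in libtvm's LibInfo table.
info = tvm.support.libinfo()
for key in ("USE_GNA_CODEGEN", "USE_GNA_RUNTIME"):
    print(key, "=", info.get(key))
```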
python/tvm/relax/backend/contrib/gna/__init__.py (19 additions, 0 deletions)
@@ -0,0 +1,19 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Pattern table and codegen for GNA"""

from . import gna # noqa: F401
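This package is not imported anywhere by default; importing it is what fires the `register_patterns(...)` call in `gna.py` below, as in:

```python
# Side-effect import: registers the "gna.*" patterns with the pattern registry.
import tvm.relax.backend.contrib.gna  # noqa: F401
```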
python/tvm/relax/backend/contrib/gna/gna.py (88 additions, 0 deletions)
@@ -0,0 +1,88 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Pattern table for GNA backend"""

from tvm.relax.dpl.pattern import is_op, wildcard
from tvm.relax.transform import PatternCheckContext

from ...pattern_registry import register_patterns


def _check_default(context: PatternCheckContext) -> bool:  # pylint: disable=unused-argument
    return True


def linear_patterns():
    """
    Return the list of linear/dense patterns for the GNA BYOC backend.
    """

    def _make_linear_pattern():
        input0 = wildcard()
        weight = wildcard()
        out = is_op("relax.matmul")(input0, weight)
        annotations = {"input": input0, "weight": weight, "root": out}
        return out, annotations

    def _linear_pattern(pattern_name):
        return (pattern_name, *_make_linear_pattern(), _check_default)

    return [_linear_pattern("gna.dense")]


def conv1d_patterns():
    """
    Return the list of conv1d patterns for the GNA BYOC backend.
    """

    def _make_conv1d_pattern():
        input0 = wildcard()
        weight = wildcard()
        out = is_op("relax.nn.conv1d")(input0, weight)
        annotations = {"input": input0, "weight": weight, "root": out}
        return out, annotations

    def _conv1d_pattern(pattern_name):
        return (pattern_name, *_make_conv1d_pattern(), _check_default)

    return [_conv1d_pattern("gna.conv1d")]


def activation_patterns():
    """
    Return the list of activation patterns for the GNA BYOC backend.
    """

    def _make_activation_pattern():
        input0 = wildcard()
        out = is_op("relax.nn.relu")(input0)
        annotations = {"input": input0, "root": out}
        return out, annotations

    def _activation_pattern(pattern_name):
        return (pattern_name, *_make_activation_pattern(), _check_default)

    return [_activation_pattern("gna.relu")]


register_patterns(
    [
        *linear_patterns(),
        *conv1d_patterns(),
        *activation_patterns(),
    ]
)
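For reviewers: a hedged sketch of how these patterns would be consumed in the usual Relax BYOC flow. The pass sequence mirrors other JSON-runtime backends and is not part of this diff; `mod` stands in for a user `IRModule` containing matmul/conv1d/relu ops:

```python
import tvm
from tvm import relax
import tvm.relax.backend.contrib.gna  # noqa: F401  (registers the "gna.*" patterns)
from tvm.relax.backend.pattern_registry import get_patterns_with_prefix

patterns = get_patterns_with_prefix("gna")
seq = tvm.transform.Sequential(
    [
        # Group matched operators into composite functions named "gna.*".
        relax.transform.FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=True),
        # Merge adjacent composites into one region function per offload target.
        relax.transform.MergeCompositeFunctions(),
        # Dispatch the annotated functions to the external codegen ("relax.ext.gna").
        relax.transform.RunCodegen(),
    ]
)
mod = seq(mod)  # mod: placeholder relax.IRModule
```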
src/relax/backend/contrib/gna/codegen.cc (193 additions, 0 deletions)
@@ -0,0 +1,193 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relax/backend/contrib/gna/codegen.cc
 * \brief Implementation of the GNA JSON serializer.
 */
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/module.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/struct_info.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/tir/expr.h>

#include <string>

#include "../codegen_json/codegen_json.h"
#include "../utils.h"

namespace tvm {
namespace relax {
namespace contrib {

using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;
using JSONSerializer = backend::contrib::JSONSerializer;
using backend::contrib::NodeEntries;

class GNAJSONSerializer : public JSONSerializer {
 public:
  GNAJSONSerializer(Map<Constant, String> constant_names, Map<Var, Expr> bindings)
      : JSONSerializer(constant_names), bindings_(bindings) {}

  using JSONSerializer::VisitExpr_;

  NodeEntries VisitExpr_(const CallNode* call_node) final {
    const auto* fn_var = call_node->op.as<VarNode>();
    ICHECK(fn_var);
    const auto fn = Downcast<Function>(bindings_[GetRef<Var>(fn_var)]);
    ICHECK(fn.defined()) << "Expects the callee to be a function.";

    auto composite_opt = fn->GetAttr<String>(attr::kComposite);
    ICHECK(composite_opt.has_value()) << "Only composite functions are supported.";

    std::string composite_name = composite_opt.value();

    NodeEntries inputs;
    for (const auto& arg : call_node->args) {
      auto res = VisitExpr(arg);
      inputs.insert(inputs.end(), res.begin(), res.end());
    }

    auto node = std::make_shared<JSONGraphNode>(composite_name, /* name_ */
                                                "kernel",       /* op_type_ */
                                                inputs, 1 /* num_outputs_ */);

    const CallNode* root_call = nullptr;
    if (composite_name.find("gna.dense") != std::string::npos) {
      root_call = backend::GetOpInFunction(fn, "relax.matmul");
    } else if (composite_name.find("gna.conv1d") != std::string::npos) {
      root_call = backend::GetOpInFunction(fn, "relax.nn.conv1d");
    } else if (composite_name.find("gna.relu") != std::string::npos) {
      root_call = backend::GetOpInFunction(fn, "relax.nn.relu");
    } else {
      LOG(FATAL) << "Unimplemented GNA pattern: " << composite_name;
    }

    SetCallNodeAttribute(node, root_call);
    return AddNode(node, GetRef<Expr>(call_node));
  }

 private:
  /*! \brief The bindings to look up composite functions. */
  Map<Var, Expr> bindings_;

  void SetCallNodeAttribute(std::shared_ptr<JSONGraphNode> node, const CallNode* call) {
    // First call the base implementation to extract the standard attributes.
    JSONSerializer::SetCallNodeAttribute(node, call);

    // Add GNA-specific attributes based on the operation.
    if (call && call->op.as<OpNode>()) {
      auto op = Downcast<Op>(call->op);
      std::string op_name = op->name;

      // Extract shape information from struct_info.
      if (!call->args.empty()) {
        StructInfo input_sinfo = GetStructInfo(call->args[0]);
        if (const auto* tensor_sinfo = input_sinfo.as<TensorStructInfoNode>()) {
          if (tensor_sinfo->shape.defined()) {
            std::vector<std::string> shape_strs;
            ShapeExpr shape = Downcast<ShapeExpr>(tensor_sinfo->shape.value());
            for (const PrimExpr& dim : shape->values) {
              if (const auto* int_imm = dim.as<tvm::tir::IntImmNode>()) {
                shape_strs.push_back(std::to_string(int_imm->value));
              } else {
                shape_strs.push_back("-1");
              }
            }
            std::vector<dmlc::any> shape_attr;
            shape_attr.emplace_back(shape_strs);
            node->SetAttr("input_shape", shape_attr);
          }

          std::vector<std::string> dtype_strs{tensor_sinfo->dtype.code() == kDLFloat ? "float32"
                                                                                     : "int32"};
          std::vector<dmlc::any> dtype_attr;
          dtype_attr.emplace_back(dtype_strs);
          node->SetAttr("input_dtype", dtype_attr);
        }
      }

      if (op_name == "relax.nn.conv1d") {
        if (call->attrs.defined()) {
          std::vector<std::string> op_attrs{"conv1d_op"};
          std::vector<dmlc::any> op_attr;
          op_attr.emplace_back(op_attrs);
          node->SetAttr("gna_op_type", op_attr);
        }
      } else if (op_name == "relax.matmul") {
        std::vector<std::string> op_attrs{"dense_op"};
        std::vector<dmlc::any> op_attr;
        op_attr.emplace_back(op_attrs);
        node->SetAttr("gna_op_type", op_attr);
      } else if (op_name == "relax.nn.relu") {
        std::vector<std::string> op_attrs{"activation_op"};
        std::vector<dmlc::any> op_attr;
        op_attr.emplace_back(op_attrs);
        node->SetAttr("gna_op_type", op_attr);
      }
    }
  }
};

/*!
 * \brief Create a GNA JSON runtime module.
 * \param functions The functions to be compiled.
 * \param unused Unused config options.
 * \param constant_names The constant names to be used.
 * \return Array of runtime modules.
 */
Array<runtime::Module> GNACompiler(Array<Function> functions, Map<String, ffi::Any> /*unused*/,
                                   Map<Constant, String> constant_names) {
  Array<runtime::Module> compiled_functions;

  for (const auto& func : functions) {
    GNAJSONSerializer serializer(constant_names, AnalyzeVar2Value(func));
    serializer.serialize(func);
    auto graph_json = serializer.GetJSON();
    auto constant_names_used = serializer.GetConstantNames();

    const auto pf = tvm::ffi::Function::GetGlobalRequired("runtime.GNAJSONRuntimeCreate");
    auto func_name = GetExtSymbol(func);
    compiled_functions.push_back(
        pf(func_name, graph_json, constant_names_used).cast<runtime::Module>());
  }

  return compiled_functions;
}

// Register the external codegen entrypoint via FFI reflection (new TVM registry).
TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("relax.ext.gna", GNACompiler);
});

} // namespace contrib
} // namespace relax

namespace target {

// Register GNA target kind
TVM_REGISTER_TARGET_KIND("gna", kDLExtDev).set_default_keys({"gna"});

} // namespace target

} // namespace tvm
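To make the serializer output concrete, here is a hand-written sketch of the JSON graph node this class would emit for a fused `gna.dense` function. The layout follows the generic TVM JSON-runtime schema; the shape/dtype values are illustrative assumptions for a (1, 128) float32 input, not output captured from this code:

```python
# Illustrative only; field values are assumed, not generated.
gna_dense_node = {
    "op": "kernel",
    "name": "gna.dense",
    "inputs": [[0, 0, 0], [1, 0, 0]],  # [node_id, output_index, version] per input
    "attrs": {
        "input_shape": [["1", "128"]],  # added by SetCallNodeAttribute above
        "input_dtype": [["float32"]],
        "gna_op_type": [["dense_op"]],
    },
}
```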