Add ComputeEngine Property for choosing Engine #2827

Closed · wants to merge 1 commit
9 changes: 9 additions & 0 deletions api/ccapi/include/common.h
@@ -43,6 +43,15 @@ enum class ExecutionMode {
VALIDATE /** Validate mode, label is necessary */
};

/**
* @brief Enumeration of layer compute engine
*/
enum LayerComputeEngine {
CPU, /**< CPU as the compute engine */
GPU, /**< GPU as the compute engine */
QNN, /**< QNN as the compute engine */
};

/**
* @brief Get the version of NNTrainer
*/
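A caller-side sketch of what this PR changes (illustrative only, not code from the diff — it assumes the ccapi <layer.h> header is on the include path and, for the gpu case, a build configured with enable-opencl; unit=10 is just a placeholder property):

  #include <layer.h>

  int main() {
    // The compute engine now travels inside the property vector as
    // "engine=<cpu|gpu|qnn>" instead of as a separate createLayer() argument.
    auto fc = ml::train::createLayer(ml::train::LayerType::LAYER_FC,
                                     {"unit=10", "engine=gpu"});
    return 0;
  }

The helper creators (FullyConnected, Swiglu, RMSNormCl, ...) follow the same route, which is why their compute_engine parameters are dropped in layer.h below.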
51 changes: 17 additions & 34 deletions api/ccapi/include/layer.h
@@ -114,14 +114,6 @@ enum LayerType {
LAYER_UNKNOWN = ML_TRAIN_LAYER_TYPE_UNKNOWN /**< Unknown */
};

/**
* @brief Enumeration of layer compute engine
*/
enum LayerComputeEngine {
CPU, /**< CPU as the compute engine */
GPU, /**< GPU as the compute engine */
};

/**
* @class Layer Base class for layers
* @brief Base class for all layers
@@ -261,16 +253,14 @@ class Layer {
*/
std::unique_ptr<Layer>
createLayer(const LayerType &type,
const std::vector<std::string> &properties = {},
const LayerComputeEngine &compute_engine = LayerComputeEngine::CPU);
const std::vector<std::string> &properties = {});

/**
* @brief Factory creator with constructor for layer
*/
std::unique_ptr<Layer>
createLayer(const std::string &type,
const std::vector<std::string> &properties = {},
const LayerComputeEngine &compute_engine = LayerComputeEngine::CPU);
const std::vector<std::string> &properties = {});

/**
* @brief General Layer Factory function to register Layer
@@ -343,37 +333,33 @@ DivideLayer(const std::vector<std::string> &properties = {}) {
/**
* @brief Helper function to create fully connected layer
*/
inline std::unique_ptr<Layer> FullyConnected(
const std::vector<std::string> &properties = {},
const LayerComputeEngine &compute_engine = LayerComputeEngine::CPU) {
return createLayer(LayerType::LAYER_FC, properties, compute_engine);
inline std::unique_ptr<Layer>
FullyConnected(const std::vector<std::string> &properties = {}) {
return createLayer(LayerType::LAYER_FC, properties);
}

/**
* @brief Helper function to create Swiglu layer
*/
inline std::unique_ptr<Layer>
Swiglu(const std::vector<std::string> &properties = {},
const LayerComputeEngine &compute_engine = LayerComputeEngine::CPU) {
return createLayer(LayerType::LAYER_SWIGLU, properties, compute_engine);
Swiglu(const std::vector<std::string> &properties = {}) {
return createLayer(LayerType::LAYER_SWIGLU, properties);
}

/**
* @brief Helper function to create RMS normalization layer for GPU
*/
inline std::unique_ptr<Layer>
RMSNormCl(const std::vector<std::string> &properties = {},
const LayerComputeEngine &compute_engine = LayerComputeEngine::GPU) {
return createLayer(LayerType::LAYER_RMSNORM, properties, compute_engine);
RMSNormCl(const std::vector<std::string> &properties = {}) {
return createLayer(LayerType::LAYER_RMSNORM, properties);
}

/**
* @brief Helper function to create Transpose layer
*/
inline std::unique_ptr<Layer>
Transpose(const std::vector<std::string> &properties = {},
const LayerComputeEngine &compute_engine = LayerComputeEngine::CPU) {
return createLayer(LayerType::LAYER_TRANSPOSE, properties, compute_engine);
Transpose(const std::vector<std::string> &properties = {}) {
return createLayer(LayerType::LAYER_TRANSPOSE, properties);
}

/**
@@ -428,27 +414,24 @@ Flatten(const std::vector<std::string> &properties = {}) {
* @brief Helper function to create reshape layer
*/
inline std::unique_ptr<Layer>
Reshape(const std::vector<std::string> &properties = {},
const LayerComputeEngine &compute_engine = LayerComputeEngine::CPU) {
return createLayer(LayerType::LAYER_RESHAPE, properties, compute_engine);
Reshape(const std::vector<std::string> &properties = {}) {
return createLayer(LayerType::LAYER_RESHAPE, properties);
}

/**
* @brief Helper function to create addition layer
*/
inline std::unique_ptr<Layer>
Addition(const std::vector<std::string> &properties = {},
const LayerComputeEngine &compute_engine = LayerComputeEngine::CPU) {
return createLayer(LayerType::LAYER_ADDITION, properties, compute_engine);
Addition(const std::vector<std::string> &properties = {}) {
return createLayer(LayerType::LAYER_ADDITION, properties);
}

/**
* @brief Helper function to create concat layer
*/
inline std::unique_ptr<Layer>
Concat(const std::vector<std::string> &properties = {},
const LayerComputeEngine &compute_engine = LayerComputeEngine::CPU) {
return createLayer(LayerType::LAYER_CONCAT, properties, compute_engine);
Concat(const std::vector<std::string> &properties = {}) {
return createLayer(LayerType::LAYER_CONCAT, properties);
}

/**
10 changes: 4 additions & 6 deletions api/ccapi/src/factory.cpp
@@ -31,18 +31,16 @@ namespace ml {
namespace train {

std::unique_ptr<Layer> createLayer(const LayerType &type,
const std::vector<std::string> &properties,
const LayerComputeEngine &compute_engine) {
return nntrainer::createLayerNode(type, properties, compute_engine);
const std::vector<std::string> &properties) {
return nntrainer::createLayerNode(type, properties);
}

/**
* @brief Factory creator with constructor for layer
*/
std::unique_ptr<Layer> createLayer(const std::string &type,
const std::vector<std::string> &properties,
const LayerComputeEngine &compute_engine) {
return nntrainer::createLayerNode(type, properties, compute_engine);
const std::vector<std::string> &properties) {
return nntrainer::createLayerNode(type, properties);
}

std::unique_ptr<Optimizer>
24 changes: 23 additions & 1 deletion nntrainer/layers/common_properties.h
@@ -18,6 +18,7 @@
#include <string>

#include <base_properties.h>
#include <common.h>
#include <connection.h>
#include <tensor.h>
#include <tensor_wrap_specs.h>
@@ -945,12 +946,33 @@ struct ActivationTypeInfo {
* @brief Activation Enumeration Information
*
*/
class Activation final : public EnumProperty<ActivationTypeInfo> {
class Activation final
: public EnumProperty<nntrainer::props::ActivationTypeInfo> {
public:
using prop_tag = enum_class_prop_tag;
static constexpr const char *key = "activation";
};

/**
* @brief Enumeration of Run Engine type
*/
struct ComputeEngineTypeInfo {
using Enum = ml::train::LayerComputeEngine;
static constexpr std::initializer_list<Enum> EnumList = {Enum::CPU, Enum::GPU,
Enum::QNN};
static constexpr const char *EnumStr[] = {"cpu", "gpu", "qnn"};
};

/**
* @brief ComputeEngine Enumeration Information
*
*/
class ComputeEngine final : public EnumProperty<ComputeEngineTypeInfo> {
public:
using prop_tag = enum_class_prop_tag;
static constexpr const char *key = "engine";
};

/**
* @brief HiddenStateActivation Enumeration Information
*
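The TypeInfo struct pairs each enum value with the string the property parser accepts; entry i of EnumStr must name entry i of EnumList. A standalone sketch of that lookup idea (plain standard C++ only; nntrainer's EnumProperty plumbing and its case-insensitive istrequal are not reproduced here):

  #include <cstring>
  #include <iostream>
  #include <iterator>

  enum class LayerComputeEngine { CPU, GPU, QNN };

  // Mirrors ComputeEngineTypeInfo: parallel arrays, matched by position.
  constexpr LayerComputeEngine EnumList[] = {
    LayerComputeEngine::CPU, LayerComputeEngine::GPU, LayerComputeEngine::QNN};
  constexpr const char *EnumStr[] = {"cpu", "gpu", "qnn"};

  LayerComputeEngine fromString(const char *s) {
    for (std::size_t i = 0; i < std::size(EnumStr); ++i)
      if (std::strcmp(s, EnumStr[i]) == 0) // the PR compares case-insensitively
        return EnumList[i];
    return LayerComputeEngine::CPU; // default, matching getComputeEngine() below
  }

  int main() {
    std::cout << (fromString("qnn") == LayerComputeEngine::QNN) << '\n'; // prints 1
  }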
75 changes: 53 additions & 22 deletions nntrainer/layers/layer_node.cpp
@@ -130,20 +130,49 @@ class SharedFrom : public Name {
*/
LayerNode::~LayerNode() = default;

/**
* @brief get the compute engine property from property string vector
* : default is CPU
* @return LayerComputeEngine Enum : CPU, GPU, QNN
*
*/
ml::train::LayerComputeEngine
getComputeEngine(const std::vector<std::string> &props) {
for (auto &prop : props) {
std::string key, value;
int status = nntrainer::getKeyValue(prop, key, value);
if (nntrainer::istrequal(key, "engine")) {
constexpr const auto data =
std::data(props::ComputeEngineTypeInfo::EnumList);
for (uint i = 0; i < props::ComputeEngineTypeInfo::EnumList.size(); ++i) {
if (nntrainer::istrequal(value.c_str(),
props::ComputeEngineTypeInfo::EnumStr[i])) {
return data[i];
}
}
}
}

return ml::train::LayerComputeEngine::CPU;
}

Review comment (Member), on getComputeEngine():

If this function is called frequently and may affect the latency, you may use:

  if (value[0] == 'q')
    return QNN;
  if (value[0] == 'g')
    return GPU;
  return CPU;

If you want a generalized code (with a bit of inefficiency), use EnumStr:

  for (i = 0; i < arraysize(EnumStr); i++) {
    if (istrequal(value, EnumStr[i]))
      return EnumList[i];
  }

/**
* @brief Layer factory creator with constructor
*/
std::unique_ptr<LayerNode>
createLayerNode(const ml::train::LayerType &type,
const std::vector<std::string> &properties,
const ml::train::LayerComputeEngine &compute_engine) {
const std::vector<std::string> &properties) {

if (getComputeEngine(properties) == ml::train::LayerComputeEngine::GPU) {
#ifdef ENABLE_OPENCL
if (compute_engine == ml::train::LayerComputeEngine::GPU) {
auto &cc = nntrainer::ClContext::Global();
return createLayerNode(cc.createObject<nntrainer::Layer>(type), properties,
compute_engine);
}
return createLayerNode(cc.createObject<nntrainer::Layer>(type), properties);
#else
throw std::invalid_argument(
"opencl layer creation without enable-opencl option");
#endif
}

auto &ac = nntrainer::AppContext::Global();
return createLayerNode(ac.createObject<nntrainer::Layer>(type), properties);
}
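A hypothetical probe of the new dispatch (not code from this PR; same <layer.h> assumption as above): with engine=gpu in the properties, an enable-opencl build routes creation through ClContext, while a CPU-only build should surface the std::invalid_argument thrown above.

  #include <layer.h>

  #include <iostream>
  #include <stdexcept>

  int main() {
    try {
      auto fc = ml::train::createLayer(ml::train::LayerType::LAYER_FC,
                                       {"unit=4", "engine=gpu"});
      std::cout << "created via ClContext (OpenCL build)\n";
    } catch (const std::invalid_argument &e) {
      std::cout << "non-OpenCL build: " << e.what() << '\n';
    }
    return 0;
  }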
@@ -153,15 +182,18 @@
*/
std::unique_ptr<LayerNode>
createLayerNode(const std::string &type,
const std::vector<std::string> &properties,
const ml::train::LayerComputeEngine &compute_engine) {
const std::vector<std::string> &properties) {

if (getComputeEngine(properties) == ml::train::LayerComputeEngine::GPU) {
#ifdef ENABLE_OPENCL
if (compute_engine == ml::train::LayerComputeEngine::GPU) {
auto &cc = nntrainer::ClContext::Global();
return createLayerNode(cc.createObject<nntrainer::Layer>(type), properties,
compute_engine);
}
return createLayerNode(cc.createObject<nntrainer::Layer>(type), properties);
#else
throw std::invalid_argument(
"opencl layer creation without enable-opencl option");
#endif
}

auto &ac = nntrainer::AppContext::Global();
return createLayerNode(ac.createObject<nntrainer::Layer>(type), properties);
}
@@ -171,16 +203,11 @@ createLayerNode(const std::string &type,
*/
std::unique_ptr<LayerNode>
createLayerNode(std::unique_ptr<nntrainer::Layer> &&layer,
const std::vector<std::string> &properties,
const ml::train::LayerComputeEngine &compute_engine) {
const std::vector<std::string> &properties) {
auto lnode = std::make_unique<LayerNode>(std::move(layer));

lnode->setProperty(properties);

if (compute_engine == ml::train::LayerComputeEngine::GPU) {
lnode->setComputeEngine(compute_engine);
}

return lnode;
}

@@ -192,10 +219,10 @@ LayerNode::LayerNode(std::unique_ptr<nntrainer::Layer> &&l) :

output_connections(),
run_context(nullptr),
layer_node_props(
new PropsType(props::Name(), props::Distribute(), props::Trainable(), {},
{}, props::SharedFrom(), props::ClipGradByGlobalNorm(),
props::Packed(), props::LossScaleForMixed())),
layer_node_props(new PropsType(
props::Name(), props::Distribute(), props::Trainable(), {}, {},
props::SharedFrom(), props::ClipGradByGlobalNorm(), props::Packed(),
props::LossScaleForMixed(), props::ComputeEngine())),
layer_node_props_realization(
new RealizationPropsType(props::Flatten(), props::Activation())),
loss(new props::Loss()),
@@ -670,6 +697,10 @@ InitLayerContext LayerNode::finalize(const std::vector<TensorDim> &input_dims,
if (!std::get<props::LossScaleForMixed>(*layer_node_props).empty())
loss_scale = std::get<props::LossScaleForMixed>(*layer_node_props).get();

if (!std::get<props::ComputeEngine>(*layer_node_props).empty()) {
compute_engine = std::get<props::ComputeEngine>(*layer_node_props).get();
}

if (!std::get<props::Packed>(*layer_node_props).empty()) {
bool isPacked = std::get<props::Packed>(*layer_node_props);
if (!isPacked) {
24 changes: 10 additions & 14 deletions nntrainer/layers/layer_node.h
@@ -53,6 +53,7 @@ class InputConnection;
class ClipGradByGlobalNorm;
class Packed;
class LossScaleForMixed;
class ComputeEngine;
} // namespace props

/**
Expand Down Expand Up @@ -994,11 +995,12 @@ will also contain the properties of the layer. The properties will be copied
upon final creation. Editing properties of the layer after init will not the
properties in the context/graph unless intended. */

using PropsType = std::tuple<props::Name, props::Distribute, props::Trainable,
std::vector<props::InputConnection>,
std::vector<props::InputShape>,
props::SharedFrom, props::ClipGradByGlobalNorm,
props::Packed, props::LossScaleForMixed>;
using PropsType =
std::tuple<props::Name, props::Distribute, props::Trainable,
std::vector<props::InputConnection>,
std::vector<props::InputShape>, props::SharedFrom,
props::ClipGradByGlobalNorm, props::Packed,
props::LossScaleForMixed, props::ComputeEngine>;

using RealizationPropsType = std::tuple<props::Flatten, props::Activation>;
/** these realization properties results in addition of new layers, hence
@@ -1070,9 +1072,7 @@ properties in the context/graph unless intended. */
*/
std::unique_ptr<LayerNode>
createLayerNode(const ml::train::LayerType &type,
const std::vector<std::string> &properties = {},
const ml::train::LayerComputeEngine &compute_engine =
ml::train::LayerComputeEngine::CPU);
const std::vector<std::string> &properties = {});

/**
* @brief LayerNode creator with constructor
@@ -1082,9 +1082,7 @@ createLayerNode(const ml::train::LayerType &type,
*/
std::unique_ptr<LayerNode>
createLayerNode(const std::string &type,
const std::vector<std::string> &properties = {},
const ml::train::LayerComputeEngine &compute_engine =
ml::train::LayerComputeEngine::CPU);
const std::vector<std::string> &properties = {});

/**
* @brief LayerNode creator with constructor
@@ -1095,9 +1093,7 @@ createLayerNode(
*/
std::unique_ptr<LayerNode>
createLayerNode(std::unique_ptr<nntrainer::Layer> &&layer,
const std::vector<std::string> &properties,
const ml::train::LayerComputeEngine &compute_engine =
ml::train::LayerComputeEngine::CPU);
const std::vector<std::string> &properties);

} // namespace nntrainer
#endif // __LAYER_NODE_H__
2 changes: 1 addition & 1 deletion nntrainer/utils/node_exporter.cpp
@@ -92,7 +92,7 @@ void Exporter::saveTflResult(
std::vector<props::InputConnection>,
std::vector<props::InputShape>, props::SharedFrom,
props::ClipGradByGlobalNorm, props::Packed,
props::LossScaleForMixed> &props,
props::LossScaleForMixed, props::ComputeEngine> &props,
const LayerNode *self) {
createIfNull(tf_node);
tf_node->setLayerNode(*self);
2 changes: 1 addition & 1 deletion nntrainer/utils/node_exporter.h
@@ -249,7 +249,7 @@ void Exporter::saveTflResult(
std::vector<props::InputConnection>,
std::vector<props::InputShape>, props::SharedFrom,
props::ClipGradByGlobalNorm, props::Packed,
props::LossScaleForMixed> &props,
props::LossScaleForMixed, props::ComputeEngine> &props,
const LayerNode *self);

class BatchNormalizationLayer;