Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ GPU ] split kernel registration from forwarding function in rmsnorm_layer_cl #2804

Merged
merged 3 commits into from
Dec 19, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions nntrainer/cl_context.cpp
Original file line number Diff line number Diff line change
@@ -53,9 +53,11 @@ static void add_default_object(ClContext &cc) {
ml::train::LayerType::LAYER_RESHAPE);
}

// @todo rmsnormlayercl also needs to be updated.
cc.registerFactory(nntrainer::createLayer<RMSNormLayerCl>,
RMSNormLayerCl::type, ml::train::LayerType::LAYER_RMSNORM);
if (RMSNormLayerCl::registerClKernels()) {
cc.registerFactory(nntrainer::createLayer<RMSNormLayerCl>,
RMSNormLayerCl::type,
ml::train::LayerType::LAYER_RMSNORM);
}

if (ConcatLayerCl::registerClKernels()) {
cc.registerFactory(nntrainer::createLayer<ConcatLayerCl>,
16 changes: 8 additions & 8 deletions nntrainer/layers/bn_layer.cpp
Original file line number Diff line number Diff line change
@@ -50,10 +50,10 @@ enum BNParams {
BatchNormalizationLayer::BatchNormalizationLayer() :
Layer(),
divider(0),
bn_props(props::Epsilon(), props::BNPARAMS_MU_INIT(),
props::BNPARAMS_VAR_INIT(), props::BNPARAMS_BETA_INIT(),
props::BNPARAMS_GAMMA_INIT(), props::Momentum(), props::Axis(),
props::WeightDecay(), props::BiasDecay()) {
bn_props(props::Epsilon(), props::MuInitializer(), props::VarInitializer(),
props::BetaInitializer(), props::GammaInitializer(),
props::Momentum(), props::Axis(), props::WeightDecay(),
props::BiasDecay()) {
wt_idx.fill(std::numeric_limits<unsigned>::max());
}

@@ -62,10 +62,10 @@ void BatchNormalizationLayer::finalize(InitLayerContext &context) {
NNTR_THROW_IF(context.getNumInputs() != 1, std::invalid_argument)
<< "Only one input is allowed for batch normalization layer";

auto &bnparams_mu = std::get<props::BNPARAMS_MU_INIT>(bn_props);
auto &bnparams_var = std::get<props::BNPARAMS_VAR_INIT>(bn_props);
auto &bnparams_beta = std::get<props::BNPARAMS_BETA_INIT>(bn_props);
auto &bnparams_gamma = std::get<props::BNPARAMS_GAMMA_INIT>(bn_props);
auto &bnparams_mu = std::get<props::MuInitializer>(bn_props);
auto &bnparams_var = std::get<props::VarInitializer>(bn_props);
auto &bnparams_beta = std::get<props::BetaInitializer>(bn_props);
auto &bnparams_gamma = std::get<props::GammaInitializer>(bn_props);
auto &weight_decay = std::get<props::WeightDecay>(bn_props);
auto &bias_decay = std::get<props::BiasDecay>(bn_props);

6 changes: 3 additions & 3 deletions nntrainer/layers/bn_layer.h
Original file line number Diff line number Diff line change
@@ -126,9 +126,9 @@ class BatchNormalizationLayer : public Layer {
std::vector<unsigned int> axes_to_reduce; /**< target axes to reduce */
std::array<unsigned int, 11>
wt_idx; /**< indices of the weights and tensors */
std::tuple<props::Epsilon, props::BNPARAMS_MU_INIT, props::BNPARAMS_VAR_INIT,
props::BNPARAMS_BETA_INIT, props::BNPARAMS_GAMMA_INIT,
props::Momentum, props::Axis, props::WeightDecay, props::BiasDecay>
std::tuple<props::Epsilon, props::MuInitializer, props::VarInitializer,
props::BetaInitializer, props::GammaInitializer, props::Momentum,
props::Axis, props::WeightDecay, props::BiasDecay>
bn_props;
};

69 changes: 52 additions & 17 deletions nntrainer/layers/cl_layers/rmsnorm_layer_cl.cpp
Original file line number Diff line number Diff line change
@@ -91,13 +91,12 @@ static constexpr size_t SINGLE_INOUT_IDX = 0;

enum RMSParams { gamma };

RMSNormLayerCl::RMSNormLayerCl() : LayerImpl() { wt_idx.fill(0); }
RMSNormLayerCl::RMSNormLayerCl() : LayerImplCl() { wt_idx.fill(0); }

void RMSNormLayerCl::finalize(InitLayerContext &context) {
std::vector<TensorDim> dim = context.getInputDimensions();
context.setOutputDimensions(dim);
auto &rmsparams_gamma =
std::get<props::RMS_NORM_GAMMA_INIT_GPU>(rmsnorm_props);
auto &rmsparams_gamma = std::get<props::GammaInitializer>(rmsnorm_props);

TensorDim gamma_dim(
1, 1, 1, dim[0].width(),
@@ -123,9 +122,6 @@ void RMSNormLayerCl::forwarding(RunLayerContext &context, bool training) {
}
}

opencl::Kernel RMSNormLayerCl::kernel_rmsnorm;
opencl::Kernel RMSNormLayerCl::kernel_rmsnorm_fp16;

void RMSNormLayerCl::rmsnormProcess(Tensor const &input, Tensor &result,
Tensor const &gamma, const float epsilon) {
bool ret = false;
@@ -138,11 +134,8 @@ void RMSNormLayerCl::rmsnormProcess(Tensor const &input, Tensor &result,
int w = input.width();

do {
ClContext::SharedPtrClKernel kernel_rmsnorm_ptr =
cl_context_ref.registerClKernel(rmsnorm_cl_kernel_, "rmsnorm_cl");
if (!kernel_rmsnorm_ptr) {
break;
}

auto kernel_rmsnorm_ptr = layer_kernel_ptrs[Kernels::RMSNORM_CL];

opencl::Buffer inputbuf(cl_context_ref.context_inst_, dim1 * sizeof(float),
true, nullptr);
@@ -219,6 +212,7 @@ void RMSNormLayerCl::rmsnormProcess(Tensor const &input, Tensor &result,
} while (false);
}

#ifdef ENABLE_FP16
void RMSNormLayerCl::rmsnormProcess_fp16(Tensor const &input, Tensor &result,
Tensor const &gamma,
const float epsilon) {
@@ -232,12 +226,8 @@ void RMSNormLayerCl::rmsnormProcess_fp16(Tensor const &input, Tensor &result,
int h = input.height();
int w = input.width();
do {
ClContext::SharedPtrClKernel kernel_rmsnorm_ptr =
cl_context_ref.registerClKernel(rmsnorm_cl_kernel_fp16_,
"rmsnorm_cl_fp16");
if (!kernel_rmsnorm_ptr) {
break;
}
auto kernel_rmsnorm_ptr = layer_kernel_ptrs[Kernels::RMSNORM_CL_FP16];

opencl::Buffer inputbuf(cl_context_ref.context_inst_,
dim1 * sizeof(cl_half), true, nullptr);

@@ -308,6 +298,7 @@ void RMSNormLayerCl::rmsnormProcess_fp16(Tensor const &input, Tensor &result,
}
} while (false);
}
#endif

void RMSNormLayerCl::incremental_forwarding(nntrainer::RunLayerContext &context,
unsigned int from, unsigned int to,
@@ -339,7 +330,11 @@ void RMSNormLayerCl::incremental_forwarding(nntrainer::RunLayerContext &context,
if (in_step.getDataType() == ml::train::TensorDim::DataType::FP32) {
rmsnormProcess(in, out, gamma, epsilon);
} else {
#ifdef ENABLE_FP16
rmsnormProcess_fp16(in, out, gamma, epsilon);
#else
throw std::runtime_error("enable-fp16 is not enabled");
#endif
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it might be good to throw an error when fp16 is not enabled!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the comment. I updated it :)

}
}

@@ -362,4 +357,44 @@ void RMSNormLayerCl::setProperty(const std::vector<std::string> &values) {
LayerImpl::setProperty(remain_props);
}

bool RMSNormLayerCl::registerClKernels() {

// check if already registered
if (!layer_kernel_ptrs.empty()) {
ml_loge("kernels for concat layer are already registered.");
return false;
}

do {

ClContext::SharedPtrClKernel kernel_rmsnorm_ptr = nullptr;

kernel_rmsnorm_ptr =
cl_context_ref.registerClKernel(rmsnorm_cl_kernel_, "rmsnorm_cl");
if (!kernel_rmsnorm_ptr) {
ml_loge("OpenCL Error: Fail to register rmsnorm_cl kernel");
break;
}
layer_kernel_ptrs.emplace_back(kernel_rmsnorm_ptr);

#ifdef ENABLE_FP16
kernel_rmsnorm_ptr = cl_context_ref.registerClKernel(
rmsnorm_cl_kernel_fp16_, "rmsnorm_cl_fp16");
if (!kernel_rmsnorm_ptr) {
ml_loge("OpenCL Error: Fail to register rmsnorm_cl_fp16 kernel");
break;
}
layer_kernel_ptrs.emplace_back(kernel_rmsnorm_ptr);
#endif

return true;

} while (false);

// clear all registered kernels if any error occurs during registration
layer_kernel_ptrs.clear();

return false;
}

} // namespace nntrainer
48 changes: 17 additions & 31 deletions nntrainer/layers/cl_layers/rmsnorm_layer_cl.h
Original file line number Diff line number Diff line change
@@ -16,7 +16,7 @@
#ifdef __cplusplus

#include <common_properties.h>
#include <layer_impl.h>
#include <layer_impl_cl.h>
#include <nntrainer_log.h>

#include <cl_context.h>
@@ -25,36 +25,11 @@

namespace nntrainer {

namespace props {

/**
* @brief RMS_NORM_GAMMA_INIT_GPU Initialization Enumeration Information
*
*/
class RMS_NORM_GAMMA_INIT_GPU final
: public ::nntrainer::EnumProperty<::nntrainer::props::InitializerInfo> {
public:
/**
* @brief Construct a RMS_NORM_GAMMA_INIT object
*/
RMS_NORM_GAMMA_INIT_GPU(
::nntrainer::Initializer value = ::nntrainer::Initializer::ONES) {
set(value);
};
using prop_tag = enum_class_prop_tag;
static constexpr const char *key = "gamma_initializer";
};
}; // namespace props

/**
* @class RMSNormLayer
* @brief RMS Norm layer
*/

class RMSNormLayerCl : public LayerImpl {

private:
inline static ClContext cl_context_ref;
class RMSNormLayerCl : public LayerImplCl {

public:
/**
@@ -118,9 +93,6 @@ class RMSNormLayerCl : public LayerImpl {
*/
const std::string getType() const override { return RMSNormLayerCl::type; };

static opencl::Kernel kernel_rmsnorm;
static opencl::Kernel kernel_rmsnorm_fp16;

/**
* @brief Process data and dimensions for rms norm operation
* @param[in] input Tensor
@@ -153,12 +125,26 @@ class RMSNormLayerCl : public LayerImpl {
*/
void setProperty(const std::vector<std::string> &values) override;

/**
* @brief registerClKernels
*/
static bool registerClKernels();

inline static const std::string type = "rmsnorm";

private:
std::array<unsigned int, 1> wt_idx;
std::tuple<props::RMS_NORM_GAMMA_INIT_GPU, props::Epsilon>

std::tuple<props::GammaInitializer, props::Epsilon>
rmsnorm_props; /**< rmsnorm layer properties */

inline static std::vector<ClContext::SharedPtrClKernel>
layer_kernel_ptrs; /**< kernel list relevant with this layer */

enum Kernels {
RMSNORM_CL,
RMSNORM_CL_FP16,
};
};
} // namespace nntrainer

8 changes: 4 additions & 4 deletions nntrainer/layers/common_properties.cpp
Original file line number Diff line number Diff line change
@@ -314,13 +314,13 @@ WeightInitializer::WeightInitializer(Initializer value) { set(value); }

BiasInitializer::BiasInitializer(Initializer value) { set(value); }

BNPARAMS_MU_INIT::BNPARAMS_MU_INIT(Initializer value) { set(value); }
MuInitializer::MuInitializer(Initializer value) { set(value); }

BNPARAMS_VAR_INIT::BNPARAMS_VAR_INIT(Initializer value) { set(value); }
VarInitializer::VarInitializer(Initializer value) { set(value); }

BNPARAMS_GAMMA_INIT::BNPARAMS_GAMMA_INIT(Initializer value) { set(value); }
GammaInitializer::GammaInitializer(Initializer value) { set(value); }

BNPARAMS_BETA_INIT::BNPARAMS_BETA_INIT(Initializer value) { set(value); }
BetaInitializer::BetaInitializer(Initializer value) { set(value); }

BasicRegularizer::BasicRegularizer(nntrainer::WeightRegularizer value) {
set(value);
32 changes: 16 additions & 16 deletions nntrainer/layers/common_properties.h
Original file line number Diff line number Diff line change
@@ -1020,57 +1020,57 @@ class BiasInitializer final : public EnumProperty<InitializerInfo> {
};

/**
* @brief BNPARAMS_MU_INIT Initialization Enumeration Information
* @brief MuInitializer Initialization Enumeration Information
*
*/
class BNPARAMS_MU_INIT final : public EnumProperty<InitializerInfo> {
class MuInitializer final : public EnumProperty<InitializerInfo> {
public:
/**
* @brief Construct a BNPARAMS_MU_INIT object
* @brief Construct a MuInitializer object
*/
BNPARAMS_MU_INIT(Initializer value = Initializer::ZEROS);
MuInitializer(Initializer value = Initializer::ZEROS);
using prop_tag = enum_class_prop_tag;
static constexpr const char *key = "moving_mean_initializer";
};

/**
* @brief BNPARAMS_VAR_INIT Initialization Enumeration Information
* @brief VarInitializer Initialization Enumeration Information
*
*/
class BNPARAMS_VAR_INIT final : public EnumProperty<InitializerInfo> {
class VarInitializer final : public EnumProperty<InitializerInfo> {
public:
/**
* @brief Construct a BNPARAMS_VAR_INIT object
* @brief Construct a VarInitializer object
*/
BNPARAMS_VAR_INIT(Initializer value = Initializer::ONES);
VarInitializer(Initializer value = Initializer::ONES);
using prop_tag = enum_class_prop_tag;
static constexpr const char *key = "moving_variance_initializer";
};

/**
* @brief BNPARAMS_GAMMA_INIT Initialization Enumeration Information
* @brief GammaInitializer Initialization Enumeration Information
*
*/
class BNPARAMS_GAMMA_INIT final : public EnumProperty<InitializerInfo> {
class GammaInitializer final : public EnumProperty<InitializerInfo> {
public:
/**
* @brief Construct a BNPARAMS_GAMMA_INIT object
* @brief Construct a GammaInitializer object
*/
BNPARAMS_GAMMA_INIT(Initializer value = Initializer::ONES);
GammaInitializer(Initializer value = Initializer::ONES);
using prop_tag = enum_class_prop_tag;
static constexpr const char *key = "gamma_initializer";
};

/**
* @brief BNPARAMS_BETA_INIT Initialization Enumeration Information
* @brief BetaInitializer Initialization Enumeration Information
*
*/
class BNPARAMS_BETA_INIT final : public EnumProperty<InitializerInfo> {
class BetaInitializer final : public EnumProperty<InitializerInfo> {
public:
/**
* @brief Construct a BNPARAMS_BETA_INIT object
* @brief Construct a BetaInitializer object
*/
BNPARAMS_BETA_INIT(Initializer value = Initializer::ZEROS);
BetaInitializer(Initializer value = Initializer::ZEROS);
using prop_tag = enum_class_prop_tag;
static constexpr const char *key = "beta_initializer";
};
10 changes: 5 additions & 5 deletions nntrainer/layers/layer_normalization_layer.cpp
Original file line number Diff line number Diff line change
@@ -38,9 +38,9 @@ enum LNParams {

LayerNormalizationLayer::LayerNormalizationLayer() :
Layer(),
layer_normalization_props(
std::vector<props::Axis>(), props::Epsilon(), props::BNPARAMS_GAMMA_INIT(),
props::BNPARAMS_BETA_INIT(), props::WeightDecay(), props::BiasDecay()) {
layer_normalization_props(std::vector<props::Axis>(), props::Epsilon(),
props::GammaInitializer(), props::BetaInitializer(),
props::WeightDecay(), props::BiasDecay()) {
wt_idx.fill(std::numeric_limits<unsigned>::max());
}

@@ -51,9 +51,9 @@ void LayerNormalizationLayer::finalize(InitLayerContext &context) {
}

auto gamma_initializer =
std::get<props::BNPARAMS_GAMMA_INIT>(layer_normalization_props).get();
std::get<props::GammaInitializer>(layer_normalization_props).get();
auto beta_initializer =
std::get<props::BNPARAMS_BETA_INIT>(layer_normalization_props).get();
std::get<props::BetaInitializer>(layer_normalization_props).get();
auto weight_decay = std::get<props::WeightDecay>(layer_normalization_props);
auto bias_decay = std::get<props::BiasDecay>(layer_normalization_props);

5 changes: 2 additions & 3 deletions nntrainer/layers/layer_normalization_layer.h
Original file line number Diff line number Diff line change
@@ -124,9 +124,8 @@ class LayerNormalizationLayer : public Layer {
remain_axes; /**< remained axes (exclusive with normalize axes) */

std::array<unsigned int, 7> wt_idx;
std::tuple<std::vector<props::Axis>, props::Epsilon,
props::BNPARAMS_GAMMA_INIT, props::BNPARAMS_BETA_INIT,
props::WeightDecay, props::BiasDecay>
std::tuple<std::vector<props::Axis>, props::Epsilon, props::GammaInitializer,
props::BetaInitializer, props::WeightDecay, props::BiasDecay>
layer_normalization_props;
};

Loading