[CUDA] Add multiclass objective for cuda_exp (#5473)
* add multiclass objective for cuda_exp

* remove debug code

* add includes requested by lint checks

* fix compilation failure for cuda with cuda-9.0

* clean code
shiyu1994 authored Sep 9, 2022
1 parent 2e9848c commit 3d4e08e
Showing 5 changed files with 212 additions and 3 deletions.
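Context for reviewers: before this commit, requesting objective=multiclass under device_type=cuda_exp logged a warning and fell back to the CPU implementation; the diff below moves the gradient computation onto the GPU. A minimal sketch of the code path this enables, via LightGBM's C API (the data file name, header flag, and iteration count are placeholders, not taken from this commit):

#include <LightGBM/c_api.h>

int main() {
  DatasetHandle train = nullptr;
  BoosterHandle booster = nullptr;
  // objective=multiclass now dispatches to CUDAMulticlassSoftmax in cuda_exp builds
  const char* params = "objective=multiclass num_class=3 device_type=cuda_exp";
  LGBM_DatasetCreateFromFile("train.txt", "header=false", nullptr, &train);
  LGBM_BoosterCreate(train, params, &booster);
  int is_finished = 0;
  for (int iter = 0; iter < 10 && !is_finished; ++iter) {
    // each boosting round calls CUDAMulticlassSoftmax::GetGradients on the GPU
    LGBM_BoosterUpdateOneIter(booster, &is_finished);
  }
  LGBM_BoosterFree(booster);
  LGBM_DatasetFree(train);
  return 0;
}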
41 changes: 41 additions & 0 deletions src/objective/cuda/cuda_multiclass_objective.cpp
@@ -0,0 +1,41 @@
/*!
 * Copyright (c) 2021 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */

#ifdef USE_CUDA_EXP

#include "cuda_multiclass_objective.hpp"

#include <string>
#include <vector>

namespace LightGBM {

CUDAMulticlassSoftmax::CUDAMulticlassSoftmax(const Config& config): MulticlassSoftmax(config) {}

CUDAMulticlassSoftmax::CUDAMulticlassSoftmax(const std::vector<std::string>& strs): MulticlassSoftmax(strs) {}

CUDAMulticlassSoftmax::~CUDAMulticlassSoftmax() {}

void CUDAMulticlassSoftmax::Init(const Metadata& metadata, data_size_t num_data) {
  MulticlassSoftmax::Init(metadata, num_data);
  cuda_label_ = metadata.cuda_metadata()->cuda_label();
  cuda_weights_ = metadata.cuda_metadata()->cuda_weights();
  cuda_softmax_buffer_.Resize(static_cast<size_t>(num_data) * static_cast<size_t>(num_class_));
  SynchronizeCUDADevice(__FILE__, __LINE__);
}

void CUDAMulticlassSoftmax::GetGradients(const double* score, score_t* gradients, score_t* hessians) const {
  LaunchGetGradientsKernel(score, gradients, hessians);
  SynchronizeCUDADevice(__FILE__, __LINE__);
}

void CUDAMulticlassSoftmax::ConvertOutputCUDA(const data_size_t num_data, const double* input, double* output) const {
  LaunchConvertOutputCUDAKernel(num_data, input, output);
}

}  // namespace LightGBM

#endif // USE_CUDA_EXP
108 changes: 108 additions & 0 deletions src/objective/cuda/cuda_multiclass_objective.cu
@@ -0,0 +1,108 @@
/*!
 * Copyright (c) 2021 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */

#ifdef USE_CUDA_EXP

#include <algorithm>

#include "cuda_multiclass_objective.hpp"

namespace LightGBM {

__device__ void SoftmaxCUDA(double* softmax_buffer, int len) {
  double wmax = softmax_buffer[0];
  for (int i = 1; i < len; ++i) {
    wmax = max(softmax_buffer[i], wmax);
  }
  double wsum = 0.0;
  for (int i = 0; i < len; ++i) {
    softmax_buffer[i] = exp(softmax_buffer[i] - wmax);
    wsum += softmax_buffer[i];
  }
  for (int i = 0; i < len; ++i) {
    softmax_buffer[i] /= wsum;
  }
}
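For reference, SoftmaxCUDA computes the max-shifted (numerically stable) softmax: subtracting the per-row maximum m before exponentiating leaves the result unchanged but keeps every exponent non-positive, so exp cannot overflow:

\mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_{j=1}^{K} e^{x_j}} = \frac{e^{x_i - m}}{\sum_{j=1}^{K} e^{x_j - m}}, \qquad m = \max_{1 \le j \le K} x_j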

template <bool USE_WEIGHT>
__global__ void GetGradientsKernel_MulticlassSoftmax(
  const double* cuda_scores, const label_t* cuda_labels, const label_t* cuda_weights,
  const double factor, const int num_class, const data_size_t num_data,
  double* cuda_softmax_buffer, score_t* cuda_out_gradients, score_t* cuda_out_hessians) {
  const data_size_t data_index = static_cast<data_size_t>(threadIdx.x + blockIdx.x * blockDim.x);
  if (data_index < num_data) {
    const data_size_t offset = data_index * num_class;
    double* softmax_result = cuda_softmax_buffer + offset;
    // scores are laid out class-major: all rows of class 0, then class 1, ...
    for (int k = 0; k < num_class; ++k) {
      softmax_result[k] = cuda_scores[k * num_data + data_index];
    }
    SoftmaxCUDA(softmax_result, num_class);
    if (!USE_WEIGHT) {
      for (int k = 0; k < num_class; ++k) {
        const double p = softmax_result[k];
        const size_t idx = static_cast<size_t>(num_data) * k + data_index;
        if (static_cast<int>(cuda_labels[data_index]) == k) {
          cuda_out_gradients[idx] = static_cast<score_t>(p - 1.0);
        } else {
          cuda_out_gradients[idx] = static_cast<score_t>(p);
        }
        cuda_out_hessians[idx] = static_cast<score_t>(factor * p * (1.0 - p));
      }
    } else {
      for (int k = 0; k < num_class; ++k) {
        const double p = softmax_result[k];
        const double weight = cuda_weights[data_index];
        const size_t idx = static_cast<size_t>(num_data) * k + data_index;
        if (static_cast<int>(cuda_labels[data_index]) == k) {
          cuda_out_gradients[idx] = static_cast<score_t>((p - 1.0) * weight);
        } else {
          cuda_out_gradients[idx] = static_cast<score_t>(p * weight);
        }
        cuda_out_hessians[idx] = static_cast<score_t>((factor * p * (1.0 - p)) * weight);
      }
    }
  }
}
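In math terms: for row i with integer label y_i, optional weight w_i, and p_{i,k} the softmax of the K raw scores, the kernel stores the standard softmax cross-entropy gradient and a scaled diagonal Hessian approximation. The scale `factor` is supplied by the CPU base class; in MulticlassSoftmax it is set to K / (K - 1) (stated from the CPU implementation, not from this diff):

g_{i,k} = w_i \left( p_{i,k} - \mathbf{1}\{ y_i = k \} \right), \qquad h_{i,k} = w_i \cdot \mathrm{factor} \cdot p_{i,k} \left( 1 - p_{i,k} \right)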

void CUDAMulticlassSoftmax::LaunchGetGradientsKernel(const double* scores, score_t* gradients, score_t* hessians) const {
  const int num_blocks = (num_data_ + GET_GRADIENTS_BLOCK_SIZE_MULTICLASS - 1) / GET_GRADIENTS_BLOCK_SIZE_MULTICLASS;
  if (cuda_weights_ == nullptr) {
    GetGradientsKernel_MulticlassSoftmax<false><<<num_blocks, GET_GRADIENTS_BLOCK_SIZE_MULTICLASS>>>(
      scores, cuda_label_, cuda_weights_, factor_, num_class_, num_data_,
      cuda_softmax_buffer_.RawData(), gradients, hessians);
  } else {
    GetGradientsKernel_MulticlassSoftmax<true><<<num_blocks, GET_GRADIENTS_BLOCK_SIZE_MULTICLASS>>>(
      scores, cuda_label_, cuda_weights_, factor_, num_class_, num_data_,
      cuda_softmax_buffer_.RawData(), gradients, hessians);
  }
}

__global__ void ConvertOutputCUDAKernel_MulticlassSoftmax(
  const int num_class, const data_size_t num_data, const double* input, double* cuda_softmax_buffer, double* output) {
  const data_size_t data_index = static_cast<data_size_t>(threadIdx.x + blockIdx.x * blockDim.x);
  if (data_index < num_data) {
    const data_size_t offset = data_index * num_class;
    double* cuda_softmax_buffer_ptr = cuda_softmax_buffer + offset;
    for (int class_index = 0; class_index < num_class; ++class_index) {
      cuda_softmax_buffer_ptr[class_index] = input[class_index * num_data + data_index];
    }
    SoftmaxCUDA(cuda_softmax_buffer_ptr, num_class);
    for (int class_index = 0; class_index < num_class; ++class_index) {
      output[class_index * num_data + data_index] = cuda_softmax_buffer_ptr[class_index];
    }
  }
}

void CUDAMulticlassSoftmax::LaunchConvertOutputCUDAKernel(
  const data_size_t num_data, const double* input, double* output) const {
  // size the grid from the num_data argument so it matches the kernel's bound check
  const int num_blocks = (num_data + GET_GRADIENTS_BLOCK_SIZE_MULTICLASS - 1) / GET_GRADIENTS_BLOCK_SIZE_MULTICLASS;
  ConvertOutputCUDAKernel_MulticlassSoftmax<<<num_blocks, GET_GRADIENTS_BLOCK_SIZE_MULTICLASS>>>(
    num_class_, num_data, input, cuda_softmax_buffer_.RawData(), output);
}

} // namespace LightGBM

#endif // USE_CUDA_EXP
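A host-side reference implementation is handy for spot-checking the kernels above. The sketch below is not part of this commit; it is a hypothetical test helper that mirrors the unweighted kernel's math and its class-major score layout (scores[k * num_data + i]) so device output can be diffed against it:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for GetGradientsKernel_MulticlassSoftmax (unweighted case).
// gradients/hessians must be pre-sized to num_class * num_data.
void ReferenceSoftmaxGradients(const std::vector<double>& scores,
                               const std::vector<float>& labels,
                               int num_class, int num_data, double factor,
                               std::vector<double>* gradients,
                               std::vector<double>* hessians) {
  std::vector<double> p(num_class);
  for (int i = 0; i < num_data; ++i) {
    // gather this row's scores from the class-major layout
    for (int k = 0; k < num_class; ++k) {
      p[k] = scores[static_cast<size_t>(k) * num_data + i];
    }
    // max-shifted softmax, same as SoftmaxCUDA
    const double m = *std::max_element(p.begin(), p.end());
    double sum = 0.0;
    for (int k = 0; k < num_class; ++k) {
      p[k] = std::exp(p[k] - m);
      sum += p[k];
    }
    for (int k = 0; k < num_class; ++k) {
      p[k] /= sum;
    }
    for (int k = 0; k < num_class; ++k) {
      const size_t idx = static_cast<size_t>(k) * num_data + i;
      (*gradients)[idx] = p[k] - (static_cast<int>(labels[i]) == k ? 1.0 : 0.0);
      (*hessians)[idx] = factor * p[k] * (1.0 - p[k]);
    }
  }
}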
60 changes: 60 additions & 0 deletions src/objective/cuda/cuda_multiclass_objective.hpp
@@ -0,0 +1,60 @@
/*!
 * Copyright (c) 2021 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
#ifndef LIGHTGBM_OBJECTIVE_CUDA_CUDA_MULTICLASS_OBJECTIVE_HPP_
#define LIGHTGBM_OBJECTIVE_CUDA_CUDA_MULTICLASS_OBJECTIVE_HPP_

#ifdef USE_CUDA_EXP

#include <LightGBM/cuda/cuda_objective_function.hpp>

#include <string>
#include <vector>

#include "../multiclass_objective.hpp"

#define GET_GRADIENTS_BLOCK_SIZE_MULTICLASS (1024)

namespace LightGBM {

class CUDAMulticlassSoftmax: public CUDAObjectiveInterface, public MulticlassSoftmax {
 public:
  explicit CUDAMulticlassSoftmax(const Config& config);

  explicit CUDAMulticlassSoftmax(const std::vector<std::string>& strs);

  ~CUDAMulticlassSoftmax();

  void Init(const Metadata& metadata, data_size_t num_data) override;

  void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override;

  void ConvertOutputCUDA(const data_size_t num_data, const double* input, double* output) const override;

  std::function<void(data_size_t, const double*, double*)> GetCUDAConvertOutputFunc() const override {
    return [this] (data_size_t num_data, const double* input, double* output) {
      ConvertOutputCUDA(num_data, input, output);
    };
  }

  bool IsCUDAObjective() const override { return true; }

 private:
  void LaunchGetGradientsKernel(const double* scores, score_t* gradients, score_t* hessians) const;

  void LaunchConvertOutputCUDAKernel(const data_size_t num_data, const double* input, double* output) const;

  // CUDA memory, held by other objects
  const label_t* cuda_label_;
  const label_t* cuda_weights_;

  // CUDA memory, held by this object
  CUDAVector<double> cuda_softmax_buffer_;
};

}  // namespace LightGBM

#endif // USE_CUDA_EXP
#endif // LIGHTGBM_OBJECTIVE_CUDA_CUDA_MULTICLASS_OBJECTIVE_HPP_
2 changes: 1 addition & 1 deletion src/objective/multiclass_objective.hpp
@@ -165,7 +165,7 @@ class MulticlassSoftmax: public ObjectiveFunction {
    }
  }

- private:
+ protected:
  double factor_;
  /*! \brief Number of data */
  data_size_t num_data_;
4 changes: 2 additions & 2 deletions src/objective/objective_function.cpp
@@ -11,6 +11,7 @@
#include "xentropy_objective.hpp"

#include "cuda/cuda_binary_objective.hpp"
+#include "cuda/cuda_multiclass_objective.hpp"
#include "cuda/cuda_rank_objective.hpp"
#include "cuda/cuda_regression_objective.hpp"
@@ -40,8 +41,7 @@ ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string&
  } else if (type == std::string("rank_xendcg")) {
    return new CUDARankXENDCG(config);
  } else if (type == std::string("multiclass")) {
-    Log::Warning("Objective multiclass is not implemented in cuda_exp version. Fall back to boosting on CPU.");
-    return new MulticlassSoftmax(config);
+    return new CUDAMulticlassSoftmax(config);
  } else if (type == std::string("multiclassova")) {
    Log::Warning("Objective multiclassova is not implemented in cuda_exp version. Fall back to boosting on CPU.");
    return new MulticlassOVA(config);
