diff --git a/cpp/bench/sg/svc.cu b/cpp/bench/sg/svc.cu index 8ddd8be441..21e046461f 100644 --- a/cpp/bench/sg/svc.cu +++ b/cpp/bench/sg/svc.cu @@ -36,17 +36,13 @@ struct SvcParams { BlobsParams blobs; raft::distance::kernels::KernelParams kernel; ML::SVM::SvmParameter svm_param; - ML::SVM::SvmModel model; }; template class SVC : public BlobsFixture { public: SVC(const std::string& name, const SvcParams& p) - : BlobsFixture(name, p.data, p.blobs), - kernel(p.kernel), - model(p.model), - svm_param(p.svm_param) + : BlobsFixture(name, p.data, p.blobs), kernel(p.kernel), svm_param(p.svm_param) { std::vector kernel_names{"linear", "poly", "rbf", "tanh"}; std::ostringstream oss; @@ -101,7 +97,6 @@ std::vector> getInputs() // SvmParameter{C, cache_size, max_iter, nochange_steps, tol, verbosity}) p.svm_param = ML::SVM::SvmParameter{1, 200, 100, 100, 1e-3, CUML_LEVEL_INFO, 0, ML::SVM::C_SVC}; - p.model = ML::SVM::SvmModel{0, 0, 0, nullptr, {}, nullptr, 0, nullptr}; std::vector rowcols = {{50000, 2, 2}, {2048, 100000, 2}, {50000, 1000, 2}}; diff --git a/cpp/bench/sg/svr.cu b/cpp/bench/sg/svr.cu index c061e53b1f..1fef40ee81 100644 --- a/cpp/bench/sg/svr.cu +++ b/cpp/bench/sg/svr.cu @@ -36,17 +36,13 @@ struct SvrParams { RegressionParams regression; raft::distance::kernels::KernelParams kernel; ML::SVM::SvmParameter svm_param; - ML::SVM::SvmModel* model; }; template class SVR : public RegressionFixture { public: SVR(const std::string& name, const SvrParams& p) - : RegressionFixture(name, p.data, p.regression), - kernel(p.kernel), - model(p.model), - svm_param(p.svm_param) + : RegressionFixture(name, p.data, p.regression), kernel(p.kernel), svm_param(p.svm_param) { std::vector kernel_names{"linear", "poly", "rbf", "tanh"}; std::ostringstream oss; @@ -69,16 +65,16 @@ class SVR : public RegressionFixture { this->data.y.data(), this->svm_param, this->kernel, - *(this->model)); + this->model); this->handle->sync_stream(this->stream); - ML::SVM::svmFreeBuffers(*this->handle, 
*(this->model)); + ML::SVM::svmFreeBuffers(*this->handle, this->model); }); } private: raft::distance::kernels::KernelParams kernel; ML::SVM::SvmParameter svm_param; - ML::SVM::SvmModel* model; + ML::SVM::SvmModel model; }; template @@ -103,7 +99,6 @@ std::vector> getInputs() // epsilon, svmType}) p.svm_param = ML::SVM::SvmParameter{1, 200, 200, 100, 1e-3, CUML_LEVEL_INFO, 0.1, ML::SVM::EPSILON_SVR}; - p.model = new ML::SVM::SvmModel{0, 0, 0, 0}; std::vector rowcols = {{50000, 2, 2}, {1024, 10000, 10}, {3000, 200, 200}}; diff --git a/cpp/include/cuml/svm/svc.hpp b/cpp/include/cuml/svm/svc.hpp index 426a049483..aa98463a8c 100644 --- a/cpp/include/cuml/svm/svc.hpp +++ b/cpp/include/cuml/svm/svc.hpp @@ -219,7 +219,7 @@ class SVC { raft::distance::kernels::KernelParams kernel_params; SvmParameter param; - SvmModel model; + SvmModelContainer model_container; /** * @brief Constructs a support vector classifier * @param handle cuML handle diff --git a/cpp/include/cuml/svm/svm_model.h b/cpp/include/cuml/svm/svm_model.h index 3237fbec3e..d65ced2a21 100644 --- a/cpp/include/cuml/svm/svm_model.h +++ b/cpp/include/cuml/svm/svm_model.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,16 +15,17 @@ */ #pragma once +#include + namespace ML { namespace SVM { // Contains array(s) for matrix storage -template struct SupportStorage { - int nnz = -1; - int* indptr = nullptr; - int* indices = nullptr; - math_t* data = nullptr; + int nnz = -1; + rmm::device_buffer* indptr; + rmm::device_buffer* indices; + rmm::device_buffer* data; }; /** @@ -39,17 +40,50 @@ struct SvmModel { //! Non-zero dual coefficients ( dual_coef[i] = \f$ y_i \alpha_i \f$). //! Size [n_support]. - math_t* dual_coefs; + rmm::device_buffer* dual_coefs; //! 
Support vector storage - can contain either CSR or dense - SupportStorage support_matrix; + SupportStorage support_matrix; //! Indices (from the training set) of the support vectors, size [n_support]. - int* support_idx; + rmm::device_buffer* support_idx; int n_classes; //!< Number of classes found in the input labels //! Device pointer for the unique classes. Size [n_classes] - math_t* unique_labels; + rmm::device_buffer* unique_labels; +}; + +/** + * Helper container that allows a SvmModel+buffer construction on the stack + */ +template +struct SvmModelContainer { + SvmModelContainer() + : dual_coef_bf(), + support_idx_bf(), + unique_labels_bf(), + support_matrix_indptr_bf(), + support_matrix_indices_bf(), + support_matrix_data_bf(), + model({0, + 0, + 0, + &dual_coef_bf, + SupportStorage{ + -1, &support_matrix_indptr_bf, &support_matrix_indices_bf, &support_matrix_data_bf}, + &support_idx_bf, + 0, + &unique_labels_bf}) + { + } + + rmm::device_buffer dual_coef_bf; + rmm::device_buffer support_idx_bf; + rmm::device_buffer unique_labels_bf; + rmm::device_buffer support_matrix_indptr_bf; + rmm::device_buffer support_matrix_indices_bf; + rmm::device_buffer support_matrix_data_bf; + SvmModel model; }; }; // namespace SVM diff --git a/cpp/src/svm/results.cuh b/cpp/src/svm/results.cuh index f33e8c4552..25ab974f24 100644 --- a/cpp/src/svm/results.cuh +++ b/cpp/src/svm/results.cuh @@ -117,23 +117,17 @@ class Results { */ void Get(const math_t* alpha, const math_t* f, - math_t** dual_coefs, - int* n_support, - int** idx, - SupportStorage* support_matrix, - math_t* b) + rmm::device_buffer& dual_coefs, + int& n_support, + rmm::device_buffer& idx, + SupportStorage& support_matrix, + math_t& b) { CombineCoefs(alpha, val_tmp.data()); GetDualCoefs(val_tmp.data(), dual_coefs, n_support); - *b = CalcB(alpha, f, *n_support); - if (*n_support > 0) { - *idx = GetSupportVectorIndices(val_tmp.data(), *n_support); - *support_matrix = CollectSupportVectorMatrix(*idx, *n_support); - } 
else { - *dual_coefs = nullptr; - *idx = nullptr; - *support_matrix = {}; - } + b = CalcB(alpha, f, n_support); + GetSupportVectorIndices(idx, val_tmp.data(), n_support); + CollectSupportVectorMatrix(support_matrix, idx, n_support); // Make sure that all pending GPU calculations finished before we return handle.sync_stream(stream); } @@ -141,32 +135,38 @@ class Results { /** * Collect support vectors into a matrix storage * + * @param [out] support_matrix containing the support vectors, size [n_support*n_cols] * @param [in] idx indices of support vectors, size [n_support] * @param [in] n_support number of support vectors - * @return pointer to a newly allocated device buffer that stores the support - * vectors, size [n_suppor*n_cols] */ - SupportStorage CollectSupportVectorMatrix(const int* idx, int n_support) + void CollectSupportVectorMatrix(SupportStorage& support_matrix, + rmm::device_buffer& idx, + int n_support) { - SupportStorage support_matrix; // allow ~1GB dense support matrix if (isDenseType() || ((size_t)n_support * n_cols * sizeof(math_t) < (1 << 30))) { - support_matrix.data = (math_t*)rmm_alloc.allocate_async( - n_support * n_cols * sizeof(math_t), rmm::CUDA_ALLOCATION_ALIGNMENT, stream); - ML::SVM::extractRows(matrix, support_matrix.data, idx, n_support, handle); + support_matrix.nnz = -1; + support_matrix.indptr->resize(0, stream); + support_matrix.indices->resize(0, stream); + support_matrix.data->resize(n_support * n_cols * sizeof(math_t), stream); + if (n_support > 0) { + ML::SVM::extractRows(matrix, + reinterpret_cast(support_matrix.data->data()), + reinterpret_cast(idx.data()), + n_support, + handle); + } } else { ML::SVM::extractRows(matrix, - &(support_matrix.indptr), - &(support_matrix.indices), - &(support_matrix.data), + *(support_matrix.indptr), + *(support_matrix.indices), + *(support_matrix.data), &(support_matrix.nnz), - idx, + reinterpret_cast(idx.data()), n_support, handle); } - - return support_matrix; } /** @@ -205,14 +205,13 @@ 
class Results { * unallocated on entry, on exit size [n_support] * @param [out] n_support number of support vectors */ - void GetDualCoefs(const math_t* val_tmp, math_t** dual_coefs, int* n_support) + void GetDualCoefs(const math_t* val_tmp, rmm::device_buffer& dual_coefs, int& n_support) { // Return only the non-zero coefficients auto select_op = [] __device__(math_t a) { return 0 != a; }; - *n_support = SelectByCoef(val_tmp, n_rows, val_tmp, select_op, val_selected.data()); - *dual_coefs = (math_t*)rmm_alloc.allocate_async( - *n_support * sizeof(math_t), rmm::CUDA_ALLOCATION_ALIGNMENT, stream); - raft::copy(*dual_coefs, val_selected.data(), *n_support, stream); + n_support = SelectByCoef(val_tmp, n_rows, val_tmp, select_op, val_selected.data()); + dual_coefs.resize(n_support * sizeof(math_t), stream); + raft::copy((math_t*)dual_coefs.data(), val_selected.data(), n_support, stream); handle.sync_stream(stream); } @@ -220,18 +219,20 @@ class Results { * Flag support vectors and also collect their indices. * Support vectors are the vectors where alpha > 0. 
* + * @param [out] idx the training set indices of the support vectors, size [n_support] * @param [in] coef dual coefficients, size [n_rows] * @param [in] n_support number of support vectors - * @return indices of the support vectors, size [n_support] */ - int* GetSupportVectorIndices(const math_t* coef, int n_support) + void GetSupportVectorIndices(rmm::device_buffer& idx, const math_t* coef, int n_support) { - auto select_op = [] __device__(math_t a) -> bool { return 0 != a; }; - SelectByCoef(coef, n_rows, f_idx.data(), select_op, idx_selected.data()); - int* idx = (int*)rmm_alloc.allocate_async( - n_support * sizeof(int), rmm::CUDA_ALLOCATION_ALIGNMENT, stream); - raft::copy(idx, idx_selected.data(), n_support, stream); - return idx; + if (n_support > 0) { + auto select_op = [] __device__(math_t a) -> bool { return 0 != a; }; + SelectByCoef(coef, n_rows, f_idx.data(), select_op, idx_selected.data()); + idx.resize(n_support * sizeof(int), stream); + raft::copy((int*)idx.data(), idx_selected.data(), n_support, stream); + } else { + idx.resize(0, stream); + } } /** diff --git a/cpp/src/svm/smosolver.cuh b/cpp/src/svm/smosolver.cuh index 384c6c236b..933b1042f6 100644 --- a/cpp/src/svm/smosolver.cuh +++ b/cpp/src/svm/smosolver.cuh @@ -103,11 +103,11 @@ void SmoSolver::Solve(MatrixViewType matrix, int n_cols, math_t* y, const math_t* sample_weight, - math_t** dual_coefs, - int* n_support, - SupportStorage* support_matrix, - int** idx, - math_t* b, + rmm::device_buffer& dual_coefs, + int& n_support, + SupportStorage& support_matrix, + rmm::device_buffer& idx, + math_t& b, int max_outer_iter, int max_inner_iter) { diff --git a/cpp/src/svm/smosolver.h b/cpp/src/svm/smosolver.h index d2355d68a5..c3c6df3216 100644 --- a/cpp/src/svm/smosolver.h +++ b/cpp/src/svm/smosolver.h @@ -124,11 +124,11 @@ class SmoSolver { int n_cols, math_t* y, const math_t* sample_weight, - math_t** dual_coefs, - int* n_support, - SupportStorage* support_matrix, - int** idx, - math_t* b, + 
rmm::device_buffer& dual_coefs, + int& n_support, + SupportStorage& support_matrix, + rmm::device_buffer& idx, + math_t& b, int max_outer_iter = -1, int max_inner_iter = 10000); diff --git a/cpp/src/svm/sparse_util.cuh b/cpp/src/svm/sparse_util.cuh index c4d0b277e9..a53bb94f7f 100644 --- a/cpp/src/svm/sparse_util.cuh +++ b/cpp/src/svm/sparse_util.cuh @@ -717,9 +717,9 @@ void extractRows(raft::device_csr_matrix_view matrix_in, */ template void extractRows(raft::device_matrix_view matrix_in, - int** indptr_out, - int** indices_out, - math_t** data_out, + rmm::device_buffer& indptr_out, + rmm::device_buffer& indices_out, + rmm::device_buffer& data_out, int* nnz, const int* row_indices, int num_indices, @@ -734,8 +734,6 @@ void extractRows(raft::device_matrix_view matrix_in * This is the specialized version for * 'CSR -> CSR (raw pointers)' * - * Warning: this specialization will allocate the the required arrays in device memory. - * * @param [in] matrix_in matrix input in CSR [i, j] * @param [out] indptr_out row index pointer of CSR output [num_indices + 1] * @param [out] indices_out column indices of CSR output [nnz = indptr_out[num_indices + 1]] @@ -747,9 +745,9 @@ void extractRows(raft::device_matrix_view matrix_in */ template void extractRows(raft::device_csr_matrix_view matrix_in, - int** indptr_out, - int** indices_out, - math_t** data_out, + rmm::device_buffer& indptr_out, + rmm::device_buffer& indices_out, + rmm::device_buffer& data_out, int* nnz, const int* row_indices, int num_indices, @@ -762,20 +760,26 @@ void extractRows(raft::device_csr_matrix_view matrix_in, math_t* data_in = matrix_in.get_elements().data(); // allocate indptr - auto* rmm_alloc = rmm::mr::get_current_device_resource(); - *indptr_out = (int*)rmm_alloc->allocate((num_indices + 1) * sizeof(int), stream); + indptr_out.resize((num_indices + 1) * sizeof(int), stream); - *nnz = computeIndptrForSubset(indptr_in, *indptr_out, row_indices, num_indices, stream); + *nnz = + 
computeIndptrForSubset(indptr_in, (int*)indptr_out.data(), row_indices, num_indices, stream); // allocate indices, data - *indices_out = (int*)rmm_alloc->allocate(*nnz * sizeof(int), stream); - *data_out = (math_t*)rmm_alloc->allocate(*nnz * sizeof(math_t), stream); + indices_out.resize(*nnz * sizeof(int), stream); + data_out.resize(*nnz * sizeof(math_t), stream); // copy with 1 warp per row for now, blocksize 256 const dim3 bs(32, 8, 1); const dim3 gs(raft::ceildiv(num_indices, (int)bs.y), 1, 1); - extractCSRRowsFromCSR<<>>( - *indptr_out, *indices_out, *data_out, indptr_in, indices_in, data_in, row_indices, num_indices); + extractCSRRowsFromCSR<<>>((int*)indptr_out.data(), + (int*)indices_out.data(), + (math_t*)data_out.data(), + indptr_in, + indices_in, + data_in, + row_indices, + num_indices); RAFT_CUDA_TRY(cudaPeekAtLastError()); } diff --git a/cpp/src/svm/svc.cu b/cpp/src/svm/svc.cu index 7c90f0214b..de9d20994b 100644 --- a/cpp/src/svm/svc.cu +++ b/cpp/src/svm/svc.cu @@ -144,40 +144,44 @@ SVC::SVC(raft::handle_t& handle, param(SvmParameter{C, cache_size, max_iter, nochange_steps, tol, verbosity}), kernel_params(kernel_params) { - model.n_support = 0; - model.dual_coefs = nullptr; - model.support_matrix = {}; - model.support_idx = nullptr; - model.unique_labels = nullptr; } template SVC::~SVC() { - svmFreeBuffers(handle, model); } template void SVC::fit( math_t* input, int n_rows, int n_cols, math_t* labels, const math_t* sample_weight) { - model.n_cols = n_cols; - if (model.dual_coefs) svmFreeBuffers(handle, model); - svcFit(handle, input, n_rows, n_cols, labels, param, kernel_params, model, sample_weight); + model_container.model.n_cols = n_cols; + svmFreeBuffers(handle, model_container.model); + svcFit(handle, + input, + n_rows, + n_cols, + labels, + param, + kernel_params, + model_container.model, + sample_weight); } template void SVC::predict(math_t* input, int n_rows, int n_cols, math_t* preds) { math_t buffer_size = param.cache_size; - 
svcPredict(handle, input, n_rows, n_cols, kernel_params, model, preds, buffer_size, true); + svcPredict( + handle, input, n_rows, n_cols, kernel_params, model_container.model, preds, buffer_size, true); } template void SVC::decisionFunction(math_t* input, int n_rows, int n_cols, math_t* preds) { math_t buffer_size = param.cache_size; - svcPredict(handle, input, n_rows, n_cols, kernel_params, model, preds, buffer_size, false); + svcPredict( + handle, input, n_rows, n_cols, kernel_params, model_container.model, preds, buffer_size, false); } // Instantiate templates for the shared library diff --git a/cpp/src/svm/svc_impl.cuh b/cpp/src/svm/svc_impl.cuh index 3bd27dc6e4..19c11fa068 100644 --- a/cpp/src/svm/svc_impl.cuh +++ b/cpp/src/svm/svc_impl.cuh @@ -72,10 +72,8 @@ void svcFitX(const raft::handle_t& handle, { rmm::device_uvector unique_labels(0, stream); model.n_classes = raft::label::getUniquelabels(unique_labels, labels, n_rows, stream); - rmm::device_async_resource_ref rmm_alloc = rmm::mr::get_current_device_resource(); - model.unique_labels = (math_t*)rmm_alloc.allocate_async( - model.n_classes * sizeof(math_t), rmm::CUDA_ALLOCATION_ALIGNMENT, stream); - raft::copy(model.unique_labels, unique_labels.data(), model.n_classes, stream); + model.unique_labels->resize(model.n_classes * sizeof(math_t), stream); + raft::copy((math_t*)model.unique_labels->data(), unique_labels.data(), model.n_classes, stream); handle_impl.sync_stream(stream); } @@ -83,7 +81,7 @@ void svcFitX(const raft::handle_t& handle, rmm::device_uvector y(n_rows, stream); raft::label::getOvrlabels( - labels, n_rows, model.unique_labels, model.n_classes, y.data(), 1, stream); + labels, n_rows, (math_t*)model.unique_labels->data(), model.n_classes, y.data(), 1, stream); raft::distance::kernels::GramMatrixBase* kernel = raft::distance::kernels::KernelFactory::create(kernel_params); @@ -93,11 +91,11 @@ void svcFitX(const raft::handle_t& handle, n_cols, y.data(), sample_weight, - &(model.dual_coefs), - 
&(model.n_support), - &(model.support_matrix), - &(model.support_idx), - &(model.b), + *(model.dual_coefs), + model.n_support, + model.support_matrix, + *(model.support_idx), + model.b, param.max_iter); model.n_cols = n_cols; handle_impl.sync_stream(stream); @@ -193,28 +191,29 @@ void svcPredictX(const raft::handle_t& handle, rmm::device_uvector l2_support(0, stream); bool is_csr_input = !isDenseType(); - bool is_csr_support = model.support_matrix.data != nullptr && model.support_matrix.nnz >= 0; - bool is_dense_support = model.support_matrix.data != nullptr && !is_csr_support; + bool is_csr_support = model.support_matrix.data->size() > 0 && model.support_matrix.nnz >= 0; + bool is_dense_support = model.support_matrix.data->size() > 0 && !is_csr_support; // Unfortunately we need runtime support for both types raft::device_matrix_view dense_support_matrix_view; if (is_dense_support) { dense_support_matrix_view = raft::make_device_strided_matrix_view( - model.support_matrix.data, model.n_support, n_cols, 0); + (math_t*)model.support_matrix.data->data(), model.n_support, n_cols, 0); } auto csr_structure_view = is_csr_support - ? raft::make_device_compressed_structure_view(model.support_matrix.indptr, - model.support_matrix.indices, - model.n_support, - n_cols, - model.support_matrix.nnz) + ? raft::make_device_compressed_structure_view( + (int*)model.support_matrix.indptr->data(), + (int*)model.support_matrix.indices->data(), + model.n_support, + n_cols, + model.support_matrix.nnz) : raft::make_device_compressed_structure_view(nullptr, nullptr, 0, 0, 0); auto csr_support_matrix_view = is_csr_support - ? raft::make_device_csr_matrix_view(model.support_matrix.data, - csr_structure_view) + ? 
raft::make_device_csr_matrix_view( + (math_t*)model.support_matrix.data->data(), csr_structure_view) : raft::make_device_csr_matrix_view(nullptr, csr_structure_view); bool transpose_kernel = is_csr_support && !is_csr_input; @@ -278,7 +277,7 @@ void svcPredictX(const raft::handle_t& handle, &one, K.data(), transpose_kernel ? model.n_support : n_batch, - model.dual_coefs, + (math_t*)model.dual_coefs->data(), 1, &null, y.data() + i, @@ -287,7 +286,7 @@ void svcPredictX(const raft::handle_t& handle, } // end of loop - math_t* labels = model.unique_labels; + math_t* labels = (math_t*)model.unique_labels->data(); math_t b = model.b; if (predict_class) { // Look up the label based on the value of the decision function: @@ -355,48 +354,24 @@ void svcPredictSparse(const raft::handle_t& handle, template void svmFreeBuffers(const raft::handle_t& handle, SvmModel& m) { - cudaStream_t stream = handle.get_stream(); - rmm::device_async_resource_ref rmm_alloc = rmm::mr::get_current_device_resource(); - if (m.dual_coefs) - rmm_alloc.deallocate_async( - m.dual_coefs, m.n_support * sizeof(math_t), rmm::CUDA_ALLOCATION_ALIGNMENT, stream); - if (m.support_idx) - rmm_alloc.deallocate_async( - m.support_idx, m.n_support * sizeof(int), rmm::CUDA_ALLOCATION_ALIGNMENT, stream); - if (m.support_matrix.indptr) { - rmm_alloc.deallocate_async(m.support_matrix.indptr, - (m.n_support + 1) * sizeof(int), - rmm::CUDA_ALLOCATION_ALIGNMENT, - stream); - m.support_matrix.indptr = nullptr; - } - if (m.support_matrix.indices) { - rmm_alloc.deallocate_async(m.support_matrix.indices, - m.support_matrix.nnz * sizeof(int), - rmm::CUDA_ALLOCATION_ALIGNMENT, - stream); - m.support_matrix.indices = nullptr; - } - if (m.support_matrix.data) { - if (m.support_matrix.nnz == -1) { - rmm_alloc.deallocate_async(m.support_matrix.data, - m.n_support * m.n_cols * sizeof(math_t), - rmm::CUDA_ALLOCATION_ALIGNMENT, - stream); - } else { - rmm_alloc.deallocate_async(m.support_matrix.data, - m.support_matrix.nnz * 
sizeof(math_t), - rmm::CUDA_ALLOCATION_ALIGNMENT, - stream); - } - } + cudaStream_t stream = handle.get_stream(); + + m.n_support = 0; + m.n_cols = 0; + m.b = (math_t)0; + m.dual_coefs->resize(0, stream); + m.dual_coefs->shrink_to_fit(stream); + m.support_idx->resize(0, stream); + m.support_idx->shrink_to_fit(stream); + m.support_matrix.indptr->resize(0, stream); + m.support_matrix.indptr->shrink_to_fit(stream); + m.support_matrix.indices->resize(0, stream); + m.support_matrix.indices->shrink_to_fit(stream); + m.support_matrix.data->resize(0, stream); + m.support_matrix.data->shrink_to_fit(stream); m.support_matrix.nnz = -1; - if (m.unique_labels) - rmm_alloc.deallocate_async( - m.unique_labels, m.n_classes * sizeof(math_t), rmm::CUDA_ALLOCATION_ALIGNMENT, stream); - m.dual_coefs = nullptr; - m.support_idx = nullptr; - m.unique_labels = nullptr; + m.unique_labels->resize(0, stream); + m.unique_labels->shrink_to_fit(stream); } }; // end namespace SVM diff --git a/cpp/src/svm/svm_api.cpp b/cpp/src/svm/svm_api.cpp index 2f6f2b6efc..5ef9438f24 100644 --- a/cpp/src/svm/svm_api.cpp +++ b/cpp/src/svm/svm_api.cpp @@ -64,9 +64,12 @@ cumlError_t cumlSpSvcFit(cumlHandle_t handle, ML::SVM::SvmModel model; + rmm::device_async_resource_ref rmm_alloc = rmm::mr::get_current_device_resource(); + cumlError_t status; raft::handle_t* handle_ptr; std::tie(handle_ptr, status) = ML::handleMap.lookupHandlePointer(handle); + cudaStream_t stream = handle_ptr->get_stream(); if (status == CUML_SUCCESS) { try { ML::SVM::svcFit(*handle_ptr, @@ -78,13 +81,47 @@ cumlError_t cumlSpSvcFit(cumlHandle_t handle, kernel_param, model, static_cast(nullptr)); - *n_support = model.n_support; - *b = model.b; - *dual_coefs = model.dual_coefs; - *x_support = model.support_matrix.data; - *support_idx = model.support_idx; - *n_classes = model.n_classes; - *unique_labels = model.unique_labels; + *n_support = model.n_support; + *b = model.b; + *n_classes = model.n_classes; + if (model.dual_coefs->size() > 0) { + 
*dual_coefs = (float*)rmm_alloc.allocate_async( + model.dual_coefs->size(), rmm::CUDA_ALLOCATION_ALIGNMENT, stream); + raft::copy( + *dual_coefs, reinterpret_cast(model.dual_coefs->data()), *n_support, stream); + } else { + *dual_coefs = nullptr; + } + if (model.support_matrix.data->size() > 0) { + *x_support = (float*)rmm_alloc.allocate_async( + model.support_matrix.data->size(), rmm::CUDA_ALLOCATION_ALIGNMENT, stream); + raft::copy(*x_support, + reinterpret_cast(model.support_matrix.data->data()), + *n_support * n_cols, + stream); + } else { + *x_support = nullptr; + } + if (model.support_idx->size() > 0) { + *support_idx = (int*)rmm_alloc.allocate_async( + model.support_idx->size(), rmm::CUDA_ALLOCATION_ALIGNMENT, stream); + raft::copy( + *support_idx, reinterpret_cast(model.support_idx->data()), *n_support, stream); + } else { + *support_idx = nullptr; + } + if (model.unique_labels->size() > 0) { + *unique_labels = (float*)rmm_alloc.allocate_async( + model.unique_labels->size(), rmm::CUDA_ALLOCATION_ALIGNMENT, stream); + raft::copy(*unique_labels, + reinterpret_cast(model.unique_labels->data()), + *n_classes, + stream); + } else { + *unique_labels = nullptr; + } + handle_ptr->sync_stream(stream); + } // TODO: Implement this // catch (const MLCommon::Exception& e) @@ -138,9 +175,12 @@ cumlError_t cumlDpSvcFit(cumlHandle_t handle, ML::SVM::SvmModel model; + rmm::device_async_resource_ref rmm_alloc = rmm::mr::get_current_device_resource(); + cumlError_t status; raft::handle_t* handle_ptr; std::tie(handle_ptr, status) = ML::handleMap.lookupHandlePointer(handle); + cudaStream_t stream = handle_ptr->get_stream(); if (status == CUML_SUCCESS) { try { ML::SVM::svcFit(*handle_ptr, @@ -152,13 +192,46 @@ cumlError_t cumlDpSvcFit(cumlHandle_t handle, kernel_param, model, static_cast(nullptr)); - *n_support = model.n_support; - *b = model.b; - *dual_coefs = model.dual_coefs; - *x_support = model.support_matrix.data; - *support_idx = model.support_idx; - *n_classes = 
model.n_classes; - *unique_labels = model.unique_labels; + *n_support = model.n_support; + *b = model.b; + *n_classes = model.n_classes; + if (model.dual_coefs->size() > 0) { + *dual_coefs = (double*)rmm_alloc.allocate_async( + model.dual_coefs->size(), rmm::CUDA_ALLOCATION_ALIGNMENT, stream); + raft::copy( + *dual_coefs, reinterpret_cast(model.dual_coefs->data()), *n_support, stream); + } else { + *dual_coefs = nullptr; + } + if (model.support_matrix.data->size() > 0) { + *x_support = (double*)rmm_alloc.allocate_async( + model.support_matrix.data->size(), rmm::CUDA_ALLOCATION_ALIGNMENT, stream); + raft::copy(*x_support, + reinterpret_cast(model.support_matrix.data->data()), + *n_support * n_cols, + stream); + } else { + *x_support = nullptr; + } + if (model.support_idx->size() > 0) { + *support_idx = (int*)rmm_alloc.allocate_async( + model.support_idx->size(), rmm::CUDA_ALLOCATION_ALIGNMENT, stream); + raft::copy( + *support_idx, reinterpret_cast(model.support_idx->data()), *n_support, stream); + } else { + *support_idx = nullptr; + } + if (model.unique_labels->size() > 0) { + *unique_labels = (double*)rmm_alloc.allocate_async( + model.unique_labels->size(), rmm::CUDA_ALLOCATION_ALIGNMENT, stream); + raft::copy(*unique_labels, + reinterpret_cast(model.unique_labels->data()), + *n_classes, + stream); + } else { + *unique_labels = nullptr; + } + handle_ptr->sync_stream(stream); } // TODO: Implement this // catch (const MLCommon::Exception& e) @@ -191,6 +264,11 @@ cumlError_t cumlSpSvcPredict(cumlHandle_t handle, float buffer_size, int predict_class) { + cumlError_t status; + raft::handle_t* handle_ptr; + std::tie(handle_ptr, status) = ML::handleMap.lookupHandlePointer(handle); + cudaStream_t stream = handle_ptr->get_stream(); + raft::distance::kernels::KernelParams kernel_param; kernel_param.kernel = (raft::distance::kernels::KernelType)kernel; kernel_param.degree = degree; @@ -198,18 +276,26 @@ cumlError_t cumlSpSvcPredict(cumlHandle_t handle, kernel_param.coef0 = 
coef0; ML::SVM::SvmModel model; - model.n_support = n_support; - model.b = b; - model.dual_coefs = dual_coefs; + model.n_support = n_support; + model.b = b; + model.n_classes = n_classes; + if (n_support > 0) { + model.dual_coefs->resize(n_support * sizeof(float), stream); + raft::copy(reinterpret_cast(model.dual_coefs->data()), dual_coefs, n_support, stream); - model.support_matrix = {.data = x_support}; - model.support_idx = nullptr; - model.n_classes = n_classes; - model.unique_labels = unique_labels; + model.support_matrix.data->resize(n_support * n_cols * sizeof(float), stream); + raft::copy(reinterpret_cast(model.support_matrix.data->data()), + x_support, + n_support * n_cols, + stream); + } + + if (n_classes > 0) { + model.unique_labels->resize(n_classes * sizeof(float), stream); + raft::copy( + reinterpret_cast(model.unique_labels->data()), unique_labels, n_classes, stream); + } - cumlError_t status; - raft::handle_t* handle_ptr; - std::tie(handle_ptr, status) = ML::handleMap.lookupHandlePointer(handle); if (status == CUML_SUCCESS) { try { ML::SVM::svcPredict( @@ -246,6 +332,11 @@ cumlError_t cumlDpSvcPredict(cumlHandle_t handle, double buffer_size, int predict_class) { + cumlError_t status; + raft::handle_t* handle_ptr; + std::tie(handle_ptr, status) = ML::handleMap.lookupHandlePointer(handle); + cudaStream_t stream = handle_ptr->get_stream(); + raft::distance::kernels::KernelParams kernel_param; kernel_param.kernel = (raft::distance::kernels::KernelType)kernel; kernel_param.degree = degree; @@ -253,18 +344,26 @@ cumlError_t cumlDpSvcPredict(cumlHandle_t handle, kernel_param.coef0 = coef0; ML::SVM::SvmModel model; - model.n_support = n_support; - model.b = b; - model.dual_coefs = dual_coefs; + model.n_support = n_support; + model.b = b; + model.n_classes = n_classes; + if (n_support > 0) { + model.dual_coefs->resize(n_support * sizeof(double), stream); + raft::copy(reinterpret_cast(model.dual_coefs->data()), dual_coefs, n_support, stream); - 
model.support_matrix = {.data = x_support}; - model.support_idx = nullptr; - model.n_classes = n_classes; - model.unique_labels = unique_labels; + model.support_matrix.data->resize(n_support * n_cols * sizeof(double), stream); + raft::copy(reinterpret_cast(model.support_matrix.data->data()), + x_support, + n_support * n_cols, + stream); + } + + if (n_classes > 0) { + model.unique_labels->resize(n_classes * sizeof(double), stream); + raft::copy( + reinterpret_cast(model.unique_labels->data()), unique_labels, n_classes, stream); + } - cumlError_t status; - raft::handle_t* handle_ptr; - std::tie(handle_ptr, status) = ML::handleMap.lookupHandlePointer(handle); if (status == CUML_SUCCESS) { try { ML::SVM::svcPredict( diff --git a/cpp/src/svm/svr_impl.cuh b/cpp/src/svm/svr_impl.cuh index 1ff62a03c7..0abdd8f0db 100644 --- a/cpp/src/svm/svr_impl.cuh +++ b/cpp/src/svm/svr_impl.cuh @@ -71,11 +71,11 @@ void svrFitX(const raft::handle_t& handle, n_cols, y, sample_weight, - &(model.dual_coefs), - &(model.n_support), - &(model.support_matrix), - &(model.support_idx), - &(model.b), + *(model.dual_coefs), + model.n_support, + model.support_matrix, + *(model.support_idx), + model.b, param.max_iter); model.n_cols = n_cols; delete kernel; diff --git a/cpp/test/sg/svc_test.cu b/cpp/test/sg/svc_test.cu index 0caad107d5..7692c68d98 100644 --- a/cpp/test/sg/svc_test.cu +++ b/cpp/test/sg/svc_test.cu @@ -502,13 +502,9 @@ class GetResultsTest : public ::testing::Test { protected: void FreeDenseSupport() { - rmm::device_async_resource_ref rmm_alloc = rmm::mr::get_current_device_resource(); - auto stream = this->handle.get_stream(); - rmm_alloc.deallocate_async(support_matrix.data, - n_coefs * n_cols * sizeof(math_t), - rmm::CUDA_ALLOCATION_ALIGNMENT, - stream); - support_matrix.data = nullptr; + model_container.model.support_matrix.data->resize(0, stream); + // this *really* deallocates + model_container.model.support_matrix.data->shrink_to_fit(stream); } void TestResults() @@ -525,38 
+521,56 @@ class GetResultsTest : public ::testing::Test { rmm::device_uvector C_dev(n_rows, stream); init_C(C, C_dev.data(), n_rows, stream); + SvmModel& model = model_container.model; + auto dense_view = raft::make_device_strided_matrix_view( x_dev.data(), n_rows, n_cols, 0); Results> res( handle, dense_view, n_rows, n_cols, y_dev.data(), C_dev.data(), C_SVC); - res.Get(alpha_dev.data(), f_dev.data(), &dual_coefs, &n_coefs, &idx, &support_matrix, &b); + res.Get(alpha_dev.data(), + f_dev.data(), + *(model.dual_coefs), + model.n_support, + *(model.support_idx), + model.support_matrix, + model.b); - ASSERT_EQ(n_coefs, 7); + ASSERT_EQ(model.n_support, 7); math_t dual_coefs_exp[] = {-0.1, -0.2, -1.5, 0.2, 0.4, 1.5, 1.5}; - EXPECT_TRUE(devArrMatchHost( - dual_coefs_exp, dual_coefs, n_coefs, MLCommon::CompareApprox(1e-6f), stream)); + EXPECT_TRUE(devArrMatchHost(dual_coefs_exp, + (math_t*)model.dual_coefs->data(), + model.n_support, + MLCommon::CompareApprox(1e-6f), + stream)); int idx_exp[] = {2, 3, 4, 6, 7, 8, 9}; - EXPECT_TRUE(devArrMatchHost(idx_exp, idx, n_coefs, MLCommon::Compare(), stream)); + EXPECT_TRUE(devArrMatchHost( + idx_exp, (int*)model.support_idx->data(), model.n_support, MLCommon::Compare(), stream)); math_t x_support_exp[] = {3, 4, 5, 7, 8, 9, 10, 13, 14, 15, 17, 18, 19, 20}; EXPECT_TRUE(devArrMatchHost(x_support_exp, - support_matrix.data, - n_coefs * n_cols, + (math_t*)model.support_matrix.data->data(), + model.n_support * n_cols, MLCommon::CompareApprox(1e-6f), stream)); - EXPECT_FLOAT_EQ(b, -6.25f); + EXPECT_FLOAT_EQ(model.b, -6.25f); // Modify the test by setting all SVs bound, then b is calculated differently math_t alpha_host2[10] = {0, 0, 1.5, 1.5, 1.5, 0, 1.5, 1.5, 1.5, 1.5}; raft::update_device(alpha_dev.data(), alpha_host2, n_rows, stream); FreeDenseSupport(); - res.Get(alpha_dev.data(), f_dev.data(), &dual_coefs, &n_coefs, &idx, &support_matrix, &b); + res.Get(alpha_dev.data(), + f_dev.data(), + *(model.dual_coefs), + model.n_support, + 
*(model.support_idx), + model.support_matrix, + model.b); FreeDenseSupport(); - EXPECT_FLOAT_EQ(b, -5.5f); + EXPECT_FLOAT_EQ(model.b, -5.5f); } raft::handle_t handle; @@ -571,11 +585,7 @@ class GetResultsTest : public ::testing::Test { // l l l/u l/u u u l/u l/u l l math_t C = 1.5; - math_t* dual_coefs; - int n_coefs; - int* idx; - SupportStorage support_matrix; - math_t b; + SvmModelContainer model_container; }; TYPED_TEST_CASE(GetResultsTest, FloatTypes); @@ -791,7 +801,7 @@ struct svmTol { }; template -void checkResults(SvmModel model, +void checkResults(SvmModel& model, smoOutput expected, cudaStream_t stream, svmTol tol = svmTol{0.001, 0.99999, -1}) @@ -809,11 +819,13 @@ void checkResults(SvmModel model, } EXPECT_LE(abs(model.n_support - expected.n_support), tol.n_sv); if (dcoef_exp) { - EXPECT_TRUE(devArrMatchHost( - dcoef_exp, model.dual_coefs, model.n_support, MLCommon::CompareApprox(1e-3f))); + EXPECT_TRUE(devArrMatchHost(dcoef_exp, + (math_t*)model.dual_coefs->data(), + model.n_support, + MLCommon::CompareApprox(1e-3f))); } math_t* dual_coefs_host = new math_t[model.n_support]; - raft::update_host(dual_coefs_host, model.dual_coefs, model.n_support, stream); + raft::update_host(dual_coefs_host, (math_t*)model.dual_coefs->data(), model.n_support, stream); raft::interruptible::synchronize(stream); math_t ay = 0; for (int i = 0; i < model.n_support; i++) { @@ -823,9 +835,9 @@ void checkResults(SvmModel model, EXPECT_LT(raft::abs(ay), ay_tol); if (x_support_exp) { - EXPECT_TRUE(model.support_matrix.data != nullptr && model.support_matrix.nnz == -1); + EXPECT_TRUE(model.support_matrix.data->size() > 0 && model.support_matrix.nnz == -1); EXPECT_TRUE(devArrMatchHost(x_support_exp, - model.support_matrix.data, + (math_t*)model.support_matrix.data->data(), model.n_support * model.n_cols, MLCommon::CompareApprox(1e-6f), stream)); @@ -833,14 +845,16 @@ void checkResults(SvmModel model, if (idx_exp) { EXPECT_TRUE(devArrMatchHost( - idx_exp, model.support_idx, 
model.n_support, MLCommon::Compare(), stream)); + idx_exp, (int*)model.support_idx->data(), model.n_support, MLCommon::Compare(), stream)); } math_t* x_support_host = new math_t[model.n_support * model.n_cols]; if (model.n_support * model.n_cols > 0) { - EXPECT_TRUE(model.support_matrix.data != nullptr && model.support_matrix.nnz == -1); - raft::update_host( - x_support_host, model.support_matrix.data, model.n_support * model.n_cols, stream); + EXPECT_TRUE(model.support_matrix.data->size() > 0 && model.support_matrix.nnz == -1); + raft::update_host(x_support_host, + (math_t*)model.support_matrix.data->data(), + model.n_support * model.n_cols, + stream); } raft::interruptible::synchronize(stream); @@ -1109,7 +1123,9 @@ TYPED_TEST(SmoSolverTest, SmoSolveTest) GramMatrixBase* kernel = KernelFactory::create(p.kernel_params); SmoSolver smo(this->handle, param, p.kernel_params.kernel, kernel); { - SvmModel model1{0, this->n_cols, 0, nullptr, {}, nullptr, 0, nullptr}; + SvmModelContainer model_container1; + SvmModel& model1 = model_container1.model; + model1.n_cols = this->n_cols; auto dense_view = raft::make_device_strided_matrix_view( this->x_dev.data(), this->n_rows, this->n_cols, 0); @@ -1118,20 +1134,21 @@ TYPED_TEST(SmoSolverTest, SmoSolveTest) this->n_cols, this->y_dev.data(), nullptr, - &model1.dual_coefs, - &model1.n_support, - &model1.support_matrix, - &model1.support_idx, - &model1.b, + *model1.dual_coefs, + model1.n_support, + model1.support_matrix, + *model1.support_idx, + model1.b, p.max_iter, p.max_inner_iter); checkResults(model1, exp, stream); - svmFreeBuffers(this->handle, model1); } // also check sparse input { - SvmModel model2{0, this->n_cols, 0, nullptr, {}, nullptr, 0, nullptr}; + SvmModelContainer model_container2; + SvmModel& model2 = model_container2.model; + model2.n_cols = this->n_cols; auto csr_structure = raft::make_device_compressed_structure_view(this->x_dev_indptr.data(), this->x_dev_indices.data(), @@ -1144,15 +1161,14 @@ 
TYPED_TEST(SmoSolverTest, SmoSolveTest) this->n_cols, this->y_dev.data(), nullptr, - &model2.dual_coefs, - &model2.n_support, - &model2.support_matrix, - &model2.support_idx, - &model2.b, + *(model2.dual_coefs), + model2.n_support, + model2.support_matrix, + *(model2.support_idx), + model2.b, p.max_iter, p.max_inner_iter); checkResults(model2, exp, stream); - svmFreeBuffers(this->handle, model2); } } } @@ -1252,7 +1268,7 @@ TYPED_TEST(SmoSolverTest, SvcTest) } SVC svc(this->handle, p.C, p.tol, p.kernel_params); svc.fit(p.x_dev, p.n_rows, p.n_cols, p.y_dev, sample_weights); - checkResults(svc.model, toSmoOutput(exp), stream); + checkResults(svc.model_container.model, toSmoOutput(exp), stream); rmm::device_uvector y_pred(p.n_rows, stream); if (p.predict) { svc.predict(p.x_dev, p.n_rows, p.n_cols, y_pred.data()); @@ -1663,7 +1679,7 @@ TYPED_TEST(SmoSolverTest, DenseBatching) SvmParameter param = getDefaultSvmParameter(); param.max_iter = 2; - SvmModel model; + SvmModelContainer model_container; TypeParam* sample_weights = nullptr; svcFit(this->handle, dense_input.data(), @@ -1672,7 +1688,7 @@ TYPED_TEST(SmoSolverTest, DenseBatching) y.data(), param, input.kernel_params, - model, + model_container.model, sample_weights); // TODO predict with subset csr & dense @@ -1682,12 +1698,12 @@ TYPED_TEST(SmoSolverTest, DenseBatching) input.n_rows, input.n_cols, input.kernel_params, - model, + model_container.model, y_pred.data(), (TypeParam)200.0, false); - svmFreeBuffers(this->handle, model); + svmFreeBuffers(this->handle, model_container.model); } } } @@ -1734,7 +1750,7 @@ TYPED_TEST(SmoSolverTest, SparseBatching) SvmParameter param = getDefaultSvmParameter(); param.max_iter = 2; - SvmModel model; + SvmModelContainer model_container; TypeParam* sample_weights = nullptr; svcFitSparse(this->handle, csr_structure.get_indptr().data(), @@ -1746,7 +1762,7 @@ TYPED_TEST(SmoSolverTest, SparseBatching) y.data(), param, input.kernel_params, - model, + model_container.model, 
sample_weights); // predict with full input @@ -1759,7 +1775,7 @@ TYPED_TEST(SmoSolverTest, SparseBatching) csr_structure.get_n_cols(), csr_structure.get_nnz(), input.kernel_params, - model, + model_container.model, y_pred.data(), (TypeParam)200.0, false); @@ -1767,7 +1783,7 @@ TYPED_TEST(SmoSolverTest, SparseBatching) y.data(), y_pred.data(), input.n_rows, MLCommon::CompareApprox(1e-6), stream); // predict with subset csr & dense for all edge cases - if (model.support_matrix.nnz >= 0) { + if (model_container.model.support_matrix.nnz >= 0) { int n_extract = 100; rmm::device_uvector sequence(n_extract, stream); auto csr_subset = raft::make_device_csr_matrix( @@ -1794,7 +1810,7 @@ TYPED_TEST(SmoSolverTest, SparseBatching) csr_subset.structure_view().get_n_cols(), csr_subset.structure_view().get_nnz(), input.kernel_params, - model, + model_container.model, y_pred_csr.data(), (TypeParam)50.0, false); @@ -1803,7 +1819,7 @@ TYPED_TEST(SmoSolverTest, SparseBatching) n_extract, input.n_cols, input.kernel_params, - model, + model_container.model, y_pred_dense.data(), (TypeParam)50.0, false); @@ -1814,7 +1830,7 @@ TYPED_TEST(SmoSolverTest, SparseBatching) stream); } - svmFreeBuffers(this->handle, model); + svmFreeBuffers(this->handle, model_container.model); } } } @@ -1858,15 +1874,10 @@ class SvrTest : public ::testing::Test { raft::update_device(x_dev.data(), x_host, n_rows * n_cols, stream); raft::update_device(y_dev.data(), y_host, n_rows, stream); - model.n_support = 0; - model.dual_coefs = nullptr; - model.support_matrix = {}; - model.support_idx = nullptr; - model.n_classes = 0; - model.unique_labels = nullptr; + svmFreeBuffers(handle, model_container.model); } - void TearDown() override { svmFreeBuffers(handle, model); } + void TearDown() override { svmFreeBuffers(handle, model_container.model); } public: void TestSvrInit() @@ -1917,34 +1928,40 @@ class SvrTest : public ::testing::Test { x_dev.data(), n_rows, n_cols, 0); Results> res( handle, dense_view, n_rows, 
n_cols, yc.data(), C_dev.data(), EPSILON_SVR); + + SvmModel& model = model_container.model; + model.n_cols = n_cols; raft::update_device(alpha.data(), alpha_host, n_train, stream); raft::update_device(f.data(), f_exp, n_train, stream); res.Get(alpha.data(), f.data(), - &model.dual_coefs, - &model.n_support, - &model.support_idx, - &model.support_matrix, - &model.b); + *(model.dual_coefs), + model.n_support, + *(model.support_idx), + model.support_matrix, + model.b); ASSERT_EQ(model.n_support, 5); math_t dc_exp[] = {0.1, 0.3, -0.4, 0.9, -0.9}; - EXPECT_TRUE(devArrMatchHost( - dc_exp, model.dual_coefs, model.n_support, MLCommon::CompareApprox(1.0e-6), stream)); + EXPECT_TRUE(devArrMatchHost(dc_exp, + (math_t*)model.dual_coefs->data(), + model.n_support, + MLCommon::CompareApprox(1.0e-6), + stream)); EXPECT_TRUE(model.support_matrix.nnz == -1); math_t x_exp[] = {1, 2, 3, 5, 6}; EXPECT_TRUE(devArrMatchHost(x_exp, - model.support_matrix.data, + (math_t*)model.support_matrix.data->data(), model.n_support * n_cols, MLCommon::CompareApprox(1.0e-6), stream)); int idx_exp[] = {0, 1, 2, 4, 5}; EXPECT_TRUE(devArrMatchHost(idx_exp, - model.support_idx, + (int*)model.support_idx->data(), model.n_support, MLCommon::CompareApprox(1.0e-6), stream)); @@ -2040,16 +2057,16 @@ class SvrTest : public ::testing::Test { y_dev.data(), p.param, p.kernel, - model, + model_container.model, sample_weights); - checkResults(model, toSmoOutput(exp), stream); + checkResults(model_container.model, toSmoOutput(exp), stream); rmm::device_uvector preds(p.n_rows, stream); svcPredict(handle, x_dev.data(), p.n_rows, p.n_cols, p.kernel, - model, + model_container.model, preds.data(), (math_t)200.0, false); @@ -2070,7 +2087,7 @@ class SvrTest : public ::testing::Test { int n_train = 2 * n_rows; const int n_cols = 1; - SvmModel model; + SvmModelContainer model_container; rmm::device_uvector x_dev; rmm::device_uvector y_dev; rmm::device_uvector C_dev; diff --git a/python/cuml/cuml/svm/svc.pyx 
b/python/cuml/cuml/svm/svc.pyx index 566aa4b762..a0c4109137 100644 --- a/python/cuml/cuml/svm/svc.pyx +++ b/python/cuml/cuml/svm/svc.pyx @@ -43,6 +43,8 @@ from cuml.svm.svm_base import SVMBase from cuml.internals.import_utils import has_sklearn from cuml.internals.array_sparse import SparseCumlArray +from rmm._lib.device_buffer cimport device_buffer + if has_sklearn(): from cuml.multiclass import MulticlassClassifier from sklearn.calibration import CalibratedClassifierCV @@ -82,22 +84,22 @@ cdef extern from "cuml/svm/svm_parameter.h" namespace "ML::SVM": cdef extern from "cuml/svm/svm_model.h" namespace "ML::SVM": - cdef cppclass SupportStorage[math_t]: + cdef cppclass SupportStorage: int nnz - int* indptr - int* indices - math_t* data + device_buffer* indptr + device_buffer* indices + device_buffer* data cdef cppclass SvmModel[math_t]: # parameters of a fitted model int n_support int n_cols math_t b - math_t *dual_coefs - SupportStorage[math_t] support_matrix - int *support_idx + device_buffer* dual_coefs + SupportStorage support_matrix + device_buffer* support_idx int n_classes - math_t *unique_labels + device_buffer* unique_labels cdef extern from "cuml/svm/svc.hpp" namespace "ML::SVM" nogil: @@ -488,7 +490,6 @@ class SVC(SVMBase, Fit the model with X and y. 
""" - self.n_classes_ = self._get_num_classes(y) # we need to check whether input X is sparse @@ -547,6 +548,8 @@ class SVC(SVMBase, if self.dtype == np.float32: model_f = new SvmModel[float]() + self._model = model_f + self._init_model_buffers() if is_sparse: with cuda_interruptible(): with nogil: @@ -562,9 +565,10 @@ class SVC(SVMBase, deref(handle_), X_data, n_rows, n_cols, y_ptr, param, _kernel_params, deref(model_f), sample_weight_ptr) - self._model = model_f elif self.dtype == np.float64: model_d = new SvmModel[double]() + self._model = model_d + self._init_model_buffers() if is_sparse: with cuda_interruptible(): with nogil: @@ -580,7 +584,6 @@ class SVC(SVMBase, deref(handle_), X_data, n_rows, n_cols, y_ptr, param, _kernel_params, deref(model_d), sample_weight_ptr) - self._model = model_d else: raise TypeError('Input data type should be float32 or float64') diff --git a/python/cuml/cuml/svm/svm_base.pyx b/python/cuml/cuml/svm/svm_base.pyx index 1a478fbc9c..32fdd51bb1 100644 --- a/python/cuml/cuml/svm/svm_base.pyx +++ b/python/cuml/cuml/svm/svm_base.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -39,6 +39,7 @@ from cuml.internals.mixins import FMajorInputTagMixin from cuml.internals.array_sparse import SparseCumlArray, SparseCumlArrayInput from libcpp cimport bool +from rmm._lib.device_buffer cimport device_buffer, DeviceBuffer cdef extern from "raft/distance/distance_types.hpp" \ namespace "raft::distance::kernels": @@ -75,22 +76,22 @@ cdef extern from "cuml/svm/svm_parameter.h" namespace "ML::SVM": cdef extern from "cuml/svm/svm_model.h" namespace "ML::SVM": - cdef cppclass SupportStorage[math_t]: + cdef cppclass SupportStorage: int nnz - int* indptr - int* indices - math_t* data + device_buffer* indptr + device_buffer* indices + device_buffer* data cdef cppclass SvmModel[math_t]: # parameters of a fitted model int n_support int n_cols math_t b - math_t *dual_coefs - SupportStorage[math_t] support_matrix - int *support_idx + device_buffer* dual_coefs + SupportStorage support_matrix + device_buffer* support_idx int n_classes - math_t *unique_labels + device_buffer* unique_labels cdef extern from "cuml/svm/svc.hpp" namespace "ML::SVM": @@ -105,9 +106,6 @@ cdef extern from "cuml/svm/svc.hpp" namespace "ML::SVM": KernelParams &kernel_params, const SvmModel[math_t] &model, math_t *preds, math_t buffer_size, bool predict_class) except + - cdef void svmFreeBuffers[math_t](const handle_t &handle, - SvmModel[math_t] &m) except + - class SVMBase(Base, FMajorInputTagMixin): @@ -252,12 +250,19 @@ class SVMBase(Base, self._intercept_ = None self.n_support_ = None + # device buffers to back managed model storage + self.__dual_coef_buffer_ = DeviceBuffer() + self.__support_idx_buffer_ = DeviceBuffer() + self.__unique_labels__buffer_ = DeviceBuffer() + self.__support_indptr_buffer_ = DeviceBuffer() + self.__support_indices_buffer_ = DeviceBuffer() + self.__support_data_buffer_ = DeviceBuffer() + self._c_kernel = self._get_c_kernel(kernel) self._gamma_val = None # the actual numerical value used for training self.coef_ = None # value of the coef_ attribute, only for 
lin kernel self.dtype = None self._model = None # structure of the model parameters - self._freeSvmBuffers = False # whether to call the C++ lib for cleanup if (kernel == 'linear' or (kernel == 'poly' and degree == 1)) \ and not getattr(type(self), "_linear_kernel_warned", False): @@ -274,17 +279,12 @@ class SVMBase(Base, # deallocate model parameters cdef SvmModel[float] *model_f cdef SvmModel[double] *model_d - cdef handle_t* handle_ = self.handle.getHandle() if self._model is not None: if self.dtype == np.float32: model_f = self._model - if self._freeSvmBuffers: - svmFreeBuffers(handle_[0], model_f[0]) del model_f elif self.dtype == np.float64: model_d = self._model - if self._freeSvmBuffers: - svmFreeBuffers(handle_[0], model_d[0]) del model_d else: raise TypeError("Unknown type for SVC class") @@ -293,6 +293,14 @@ class SVMBase(Base, except AttributeError: pass + # re-init / clean all storages + self.__dual_coef_buffer_ = DeviceBuffer() + self.__support_idx_buffer_ = DeviceBuffer() + self.__unique_labels__buffer_ = DeviceBuffer() + self.__support_indptr_buffer_ = DeviceBuffer() + self.__support_indices_buffer_ = DeviceBuffer() + self.__support_data_buffer_ = DeviceBuffer() + self._model = None def _get_c_kernel(self, kernel): @@ -402,6 +410,32 @@ class SVMBase(Base, param.svmType = self.svmType return param + def _init_model_buffers(self): + + if self._model is None: + raise AttributeError("_init_model_buffers is only available after _model is set") + + cdef SvmModel[float] *model_f + cdef SvmModel[double] *model_d + + if self.dtype == np.float32: + model_f = self._model + model_f.dual_coefs = (self.__dual_coef_buffer_).c_obj.get() + model_f.support_idx = (self.__support_idx_buffer_).c_obj.get() + model_f.unique_labels = (self.__unique_labels__buffer_).c_obj.get() + model_f.support_matrix.indptr = (self.__support_indptr_buffer_).c_obj.get() + model_f.support_matrix.indices = (self.__support_indices_buffer_).c_obj.get() + model_f.support_matrix.data = 
(self.__support_data_buffer_).c_obj.get() + + else: + model_d = self._model + model_d.dual_coefs = (self.__dual_coef_buffer_).c_obj.get() + model_d.support_idx = (self.__support_idx_buffer_).c_obj.get() + model_d.unique_labels = (self.__unique_labels__buffer_).c_obj.get() + model_d.support_matrix.indptr = (self.__support_indptr_buffer_).c_obj.get() + model_d.support_matrix.indices = (self.__support_indices_buffer_).c_obj.get() + model_d.support_matrix.data = (self.__support_data_buffer_).c_obj.get() + @cuml.internals.api_base_return_any_skipall def _get_svm_model(self): """ Wrap the fitted model parameters into an SvmModel structure. @@ -410,7 +444,8 @@ class SVMBase(Base, """ cdef SvmModel[float] *model_f cdef SvmModel[double] *model_d - if self.dual_coef_ is None: + + if self.n_support_ is None: # the model is not fitted in this case return None if self.dtype == np.float32: @@ -418,46 +453,36 @@ class SVMBase(Base, model_f.n_support = self.n_support_ model_f.n_cols = self.n_cols model_f.b = self._intercept_.item() - model_f.dual_coefs = \ - self.dual_coef_.ptr - if isinstance(self.support_vectors_, SparseCumlArray): - model_f.support_matrix.nnz = self.support_vectors_.nnz - model_f.support_matrix.indptr = self.support_vectors_.indptr.ptr - model_f.support_matrix.indices = self.support_vectors_.indices.ptr - model_f.support_matrix.data = self.support_vectors_.data.ptr - else: - model_f.support_matrix.data = self.support_vectors_.ptr - model_f.support_idx = \ - self.support_.ptr + model_f.dual_coefs = (self.__dual_coef_buffer_).c_obj.get() + model_f.support_matrix.indptr = (self.__support_indptr_buffer_).c_obj.get() + model_f.support_matrix.indices = (self.__support_indices_buffer_).c_obj.get() + model_f.support_matrix.data = (self.__support_data_buffer_).c_obj.get() + model_f.support_idx = (self.__support_idx_buffer_).c_obj.get() + model_f.unique_labels = (self.__unique_labels__buffer_).c_obj.get() + + if self.__support_indptr_buffer_.size > 0: + 
model_f.support_matrix.nnz = (self.__support_data_buffer_.size // 4) + model_f.n_classes = self.n_classes_ - if self.n_classes_ > 0: - model_f.unique_labels = \ - self._unique_labels_.ptr - else: - model_f.unique_labels = NULL + return model_f else: model_d = new SvmModel[double]() model_d.n_support = self.n_support_ model_d.n_cols = self.n_cols model_d.b = self._intercept_.item() - model_d.dual_coefs = \ - self.dual_coef_.ptr - if isinstance(self.support_vectors_, SparseCumlArray): - model_d.support_matrix.nnz = self.support_vectors_.nnz - model_d.support_matrix.indptr = self.support_vectors_.indptr.ptr - model_d.support_matrix.indices = self.support_vectors_.indices.ptr - model_d.support_matrix.data = self.support_vectors_.data.ptr - else: - model_d.support_matrix.data = self.support_vectors_.ptr - model_d.support_idx = \ - self.support_.ptr + model_d.dual_coefs = (self.__dual_coef_buffer_).c_obj.get() + model_d.support_matrix.indptr = (self.__support_indptr_buffer_).c_obj.get() + model_d.support_matrix.indices = (self.__support_indices_buffer_).c_obj.get() + model_d.support_matrix.data = (self.__support_data_buffer_).c_obj.get() + model_d.support_idx = (self.__support_idx_buffer_).c_obj.get() + model_d.unique_labels = (self.__unique_labels__buffer_).c_obj.get() + + if self.__support_indptr_buffer_.size > 0: + model_d.support_matrix.nnz = (self.__support_data_buffer_.size // 8) + model_d.n_classes = self.n_classes_ - if self.n_classes_ > 0: - model_d.unique_labels = \ - self._unique_labels_.ptr - else: - model_d.unique_labels = NULL + return model_d def _unpack_svm_model(self, b, n_support, dual_coefs, support_idx, nnz, indptr, indices, data, n_classes, unique_labels): @@ -520,37 +545,32 @@ class SVMBase(Base, cdef SvmModel[float] *model_f cdef SvmModel[double] *model_d - # Mark that the C++ layer should free the parameter vectors - # If we could pass the deviceArray deallocator as finalizer for the - # device_array_from_ptr function, then this would not be 
necessary. - self._freeSvmBuffers = True - if self.dtype == np.float32: model_f = self._model self._unpack_svm_model( model_f.b, model_f.n_support, - model_f.dual_coefs, - model_f.support_idx, + model_f.dual_coefs.data(), + model_f.support_idx.data(), model_f.support_matrix.nnz, - model_f.support_matrix.indptr, - model_f.support_matrix.indices, - model_f.support_matrix.data, + model_f.support_matrix.indptr.data(), + model_f.support_matrix.indices.data(), + model_f.support_matrix.data.data(), model_f.n_classes, - model_f.unique_labels) + model_f.unique_labels.data()) else: model_d = self._model self._unpack_svm_model( model_d.b, model_d.n_support, - model_d.dual_coefs, - model_d.support_idx, + model_d.dual_coefs.data(), + model_d.support_idx.data(), model_d.support_matrix.nnz, - model_d.support_matrix.indptr, - model_d.support_matrix.indices, - model_d.support_matrix.data, + model_d.support_matrix.indptr.data(), + model_d.support_matrix.indices.data(), + model_d.support_matrix.data.data(), model_d.n_classes, - model_d.unique_labels) + model_d.unique_labels.data()) if self.n_support_ == 0: self.dual_coef_ = CumlArray.empty( @@ -679,6 +699,18 @@ class SVMBase(Base, state = self.__dict__.copy() del state['handle'] del state['_model'] + + # the following are only wrappers around data owned by the DeviceBuffers + # We don't want to serialize the data twice + if 'dual_coef_' in state: + del state['dual_coef_'] + if 'support_' in state: + del state['support_'] + if 'support_vectors_' in state: + del state['support_vectors_'] + if '_unique_labels_' in state: + del state['_unique_labels_'] + return state def __setstate__(self, state): @@ -686,4 +718,7 @@ class SVMBase(Base, verbose=state['verbose']) self.__dict__.update(state) self._model = self._get_svm_model() - self._freeSvmBuffers = False + + # unpack model & buffer locations + if self._model is not None: + self._unpack_model() diff --git a/python/cuml/cuml/svm/svr.pyx b/python/cuml/cuml/svm/svr.pyx index 
a2527f4358..e383b050d1 100644 --- a/python/cuml/cuml/svm/svr.pyx +++ b/python/cuml/cuml/svm/svr.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -35,6 +35,8 @@ from cuml.common import input_to_cuml_array from libcpp cimport nullptr from cuml.svm.svm_base import SVMBase +from rmm._lib.device_buffer cimport device_buffer + cdef extern from "cuml/matrix/kernelparams.h" namespace "MLCommon::Matrix": enum KernelType: LINEAR, POLYNOMIAL, RBF, TANH @@ -62,22 +64,22 @@ cdef extern from "cuml/svm/svm_parameter.h" namespace "ML::SVM": cdef extern from "cuml/svm/svm_model.h" namespace "ML::SVM": - cdef cppclass SupportStorage[math_t]: + cdef cppclass SupportStorage: int nnz - int* indptr - int* indices - math_t* data + device_buffer* indptr + device_buffer* indices + device_buffer* data cdef cppclass SvmModel[math_t]: # parameters of a fitted model int n_support int n_cols math_t b - math_t *dual_coefs - SupportStorage[math_t] support_matrix - int *support_idx + device_buffer* dual_coefs + SupportStorage support_matrix + device_buffer* support_idx int n_classes - math_t *unique_labels + device_buffer* unique_labels cdef extern from "cuml/svm/svr.hpp" namespace "ML::SVM" nogil: @@ -302,6 +304,8 @@ class SVR(SVMBase, RegressorMixin): if self.dtype == np.float32: model_f = new SvmModel[float]() + self._model = model_f + self._init_model_buffers() if is_sparse: svrFitSparse(handle_[0], X_indptr, X_indices, X_data, n_rows, n_cols, n_nnz, @@ -311,9 +315,10 @@ class SVR(SVMBase, RegressorMixin): svrFit(handle_[0], X_data, n_rows, n_cols, y_ptr, param, _kernel_params, model_f[0], sample_weight_ptr) - self._model = model_f elif self.dtype == np.float64: model_d = new SvmModel[double]() + self._model = model_d + self._init_model_buffers() if is_sparse: svrFitSparse(handle_[0], 
X_indptr, X_indices, X_data, n_rows, n_cols, n_nnz, @@ -323,7 +328,6 @@ class SVR(SVMBase, RegressorMixin): svrFit(handle_[0], X_data, n_rows, n_cols, y_ptr, param, _kernel_params, model_d[0], sample_weight_ptr) - self._model = model_d else: raise TypeError('Input data type should be float32 or float64') diff --git a/python/cuml/cuml/tests/test_pickle.py b/python/cuml/cuml/tests/test_pickle.py index 598ebbd7e3..3b87c0ce3e 100644 --- a/python/cuml/cuml/tests/test_pickle.py +++ b/python/cuml/cuml/tests/test_pickle.py @@ -714,10 +714,10 @@ def create_mod(): iris_selection = np.random.RandomState(42).choice( [True, False], 150, replace=True, p=[0.75, 0.25] ) - X_train = iris.data[iris_selection] + X_train = iris.data[iris_selection].astype(datatype) if sparse: - X_train = scipy_sparse.csr_matrix(X_train) - y_train = iris.target[iris_selection] + X_train = scipy_sparse.csr_matrix(X_train).astype(datatype) + y_train = iris.target[iris_selection].astype(datatype) if not multiclass: y_train = (y_train > 0).astype(datatype) data = [X_train, y_train] @@ -751,8 +751,8 @@ def create_mod(): iris_selection = np.random.RandomState(42).choice( [True, False], 150, replace=True, p=[0.75, 0.25] ) - X_train = iris.data[iris_selection] - y_train = iris.target[iris_selection] + X_train = iris.data[iris_selection].astype(datatype) + y_train = iris.target[iris_selection].astype(datatype) if not multiclass: y_train = (y_train > 0).astype(datatype) data = [X_train, y_train]