[WIP] Update xgboost comparison notebook with 6 new datasets, update legate-core to 24.09 #142

Closed
wants to merge 10 commits
1 change: 0 additions & 1 deletion CMakeLists.txt
@@ -33,7 +33,6 @@ find_package(BLAS REQUIRED)
 legate_add_cpp_subdirectory(src TARGET legateboost EXPORT legateboost-export)
 
 legate_add_cffi(${CMAKE_SOURCE_DIR}/src/legateboost.h TARGET legateboost)
-legate_python_library_template(legateboost)
 legate_default_python_install(legateboost EXPORT legateboost-export)
 
 
3 changes: 3 additions & 0 deletions depth_16_gpu_time.png
Binary file not shown.
4 changes: 2 additions & 2 deletions examples/notebook/xgboost_comparison.ipynb
Git LFS file not shown
4 changes: 2 additions & 2 deletions legateboost/legateboost.py
@@ -17,7 +17,7 @@
 from .models import BaseModel, Tree
 from .objectives import BaseObjective, objectives
 from .shapley import global_shapley_attributions, local_shapley_attributions
-from .utils import PickleCunumericMixin, preround
+from .utils import PickleCunumericMixin
 
 if TYPE_CHECKING:
     from .callbacks import TrainingCallback
@@ -200,7 +200,7 @@ def _get_weighted_gradient(
         g *= mask[:, None]
         h *= mask[:, None]
 
-        return preround(g), preround(h)
+        return g, h
 
     def _partial_fit(
         self,
91 changes: 91 additions & 0 deletions legateboost/library.py
@@ -0,0 +1,91 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+import os
+import platform
+from ctypes import CDLL, RTLD_GLOBAL
+from typing import Any
+
+from cffi import FFI
+
+from legate.core import get_legate_runtime
+
+
+def dlopen_no_autoclose(ffi: Any, lib_path: str) -> Any:
+    # Use an already-opened library handle, which cffi will convert to a
+    # regular FFI object (using the definitions previously added using
+    # ffi.cdef), but will not automatically dlclose() on collection.
+    lib = CDLL(lib_path, mode=RTLD_GLOBAL)
+    return ffi.dlopen(ffi.cast("void *", lib._handle))
+
+
+class UserLibrary:
+    def __init__(self, name: str) -> None:
+        self.name = name
+        self.shared_object: Any = None
+
+        shared_lib_path = self.get_shared_library()
+        if shared_lib_path is not None:
+            header = self.get_c_header()
+            ffi = FFI()
+            if header is not None:
+                ffi.cdef(header)
+            # Don't use ffi.dlopen(), because that will call dlclose()
+            # automatically when the object gets collected, thus removing
+            # symbols that may be needed when destroying C++ objects later
+            # (e.g. vtable entries, which will be queried for virtual
+            # destructors), causing errors at shutdown.
+            shared_lib = dlopen_no_autoclose(ffi, shared_lib_path)
+            self.initialize(shared_lib)
+            callback_name = self.get_registration_callback()
+            callback = getattr(shared_lib, callback_name)
+            callback()
+        else:
+            self.initialize(None)
+
+    @property
+    def cffi(self) -> Any:
+        return self.shared_object
+
+    def get_name(self) -> str:
+        return self.name
+
+    def get_shared_library(self) -> str:
+        from legateboost.install_info import libpath
+
+        return os.path.join(libpath, f"liblegateboost{self.get_library_extension()}")
+
+    def get_c_header(self) -> str:
+        from legateboost.install_info import header
+
+        return header
+
+    def get_registration_callback(self) -> str:
+        return "legateboost_perform_registration"
+
+    def initialize(self, shared_object: Any) -> None:
+        self.shared_object = shared_object
+
+    def destroy(self) -> None:
+        pass
+
+    @staticmethod
+    def get_library_extension() -> str:
+        os_name = platform.system()
+        if os_name == "Linux":
+            return ".so"
+        elif os_name == "Darwin":
+            return ".dylib"
+        raise RuntimeError(f"unknown platform {os_name!r}")
+
+
+user_lib = UserLibrary("legateboost")
+user_context = get_legate_runtime().find_library(user_lib.get_name())
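Note: the dlopen_no_autoclose pattern above can be exercised in isolation. A minimal sketch, assuming a Linux system where glibc's libm.so.6 stands in for the real liblegateboost:

from ctypes import CDLL, RTLD_GLOBAL

from cffi import FFI

ffi = FFI()
ffi.cdef("double cos(double x);")

# Opening via ctypes keeps our own reference to the handle; passing the raw
# handle to cffi yields an FFI object that never calls dlclose() when it is
# garbage collected, so symbols such as C++ vtable entries stay resolvable
# at shutdown.
lib = CDLL("libm.so.6", mode=RTLD_GLOBAL)
m = ffi.dlopen(ffi.cast("void *", lib._handle))

assert m.cos(0.0) == 1.0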
10 changes: 2 additions & 8 deletions legateboost/objectives.py
@@ -17,7 +17,7 @@
     NormalLLMetric,
     QuantileMetric,
 )
-from .utils import mod_col_by_idx, preround, sample_average, set_col_by_idx
+from .utils import mod_col_by_idx, sample_average, set_col_by_idx
 
 GradPair: TypeAlias = Tuple[cn.ndarray, cn.ndarray]
 
@@ -110,8 +110,6 @@ def initialise_prediction(
     ) -> cn.ndarray:
         assert y.ndim == 2
         if boost_from_average:
-            y = preround(y)
-            w = preround(w)
             return cn.sum(y * w[:, None], axis=0) / cn.sum(w)
         else:
             return cn.zeros(y.shape[1])
@@ -197,11 +195,9 @@ def initialise_prediction(
         assert y.ndim == 2
         pred = cn.zeros((y.shape[1], 2))
         if boost_from_average:
-            y = preround(y)
-            w = preround(w)
             mean = cn.sum(y * w[:, None], axis=0) / cn.sum(w)
             var = (y - mean) * (y - mean) * w[:, None]
-            var = cn.sum(preround(var), axis=0) / cn.sum(w)
+            var = cn.sum(var, axis=0) / cn.sum(w)
             pred[:, 0] = mean
             pred[:, 1] = cn.log(var) / 2
         return pred.reshape(-1)
@@ -438,8 +434,6 @@ def initialise_prediction(
         # Instead fit a normal distribution to the data and use that
         # to estimate quantiles.
         if boost_from_average:
-            y = preround(y)
-            w = preround(w)
             mean = cn.sum(y * w[:, None], axis=0) / cn.sum(w)
             var = cn.sum((y - mean) * (y - mean) * w[:, None], axis=0) / cn.sum(w)
             init = cn.array(
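For intuition, the normal-distribution initialisation kept by the second hunk boils down to a weighted mean and variance, with log(var) / 2 stored as the log of the standard deviation. A NumPy sketch of that computation, minus the dropped preround calls (the array shapes here are illustrative assumptions):

import numpy as np

rng = np.random.default_rng(0)
y = rng.normal(loc=2.0, scale=3.0, size=(1000, 1))  # targets, one output
w = np.ones(y.shape[0])                             # sample weights

mean = np.sum(y * w[:, None], axis=0) / np.sum(w)
var = np.sum((y - mean) * (y - mean) * w[:, None], axis=0) / np.sum(w)
# The second parameter is log(sigma): log(var) / 2 == log(sqrt(var)).
init = np.stack([mean, np.log(var) / 2], axis=-1).reshape(-1)
print(init)  # approximately [2.0, log(3.0) ~= 1.10]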
6 changes: 1 addition & 5 deletions legateboost/test/models/test_tree.py
@@ -3,12 +3,8 @@
 
 import cunumeric as cn
 import legateboost as lb
-from legateboost.testing.utils import check_determinism, non_increasing
-
-
-@pytest.mark.parametrize("max_depth", [0, 8])
-def test_determinism(max_depth):
-    check_determinism(lb.models.Tree(max_depth=max_depth))
+from ..test_utils import non_increasing
 
 
 def test_basic():
4 changes: 2 additions & 2 deletions legateboost/test/test_estimator.py
@@ -47,12 +47,12 @@ def test_update(init):
     update_train_loss = metric.metric(y, model.predict(X), cn.ones(y.shape[0]))
     assert update_train_loss < half_data_train_loss
 
-    # check that updating with same dataset results in exact same model
+    # check that updating with same dataset results in same model
     model.fit(X, y)
     pred = model.predict(X)
     model.update(X, y)
     updated_pred = model.predict(X)
-    assert (pred == updated_pred).all()
+    assert np.allclose(pred, updated_pred)
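The relaxation from bitwise equality to np.allclose follows from dropping preround: without pre-rounding, floating-point sums depend on evaluation order, so two identical fits may differ in the last bits. A minimal illustration of the underlying effect (assuming NumPy; whether the two sums actually differ is platform dependent):

import numpy as np

x = np.random.default_rng(0).standard_normal(1_000_000).astype(np.float32)
# Summing the same values in a different order accumulates different
# rounding error, so results are close but not guaranteed bitwise equal.
forward = x.sum()
backward = x[::-1].sum()
assert np.allclose(forward, backward)  # holds
# forward == backward may or may not hold, hence allclose in the test.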
23 changes: 0 additions & 23 deletions legateboost/utils.py
@@ -74,29 +74,6 @@ def mod_col_by_idx(a: cn.ndarray, b: cn.ndarray, delta: float) -> None:
     return
 
 
-one = cn.array(1.0, dtype=cn.float64)
-two = cn.array(2.0, dtype=cn.float64)
-eps = cn.finfo(cn.float64).eps
-
-
-def preround(x: cn.ndarray) -> cn.ndarray:
-    """Apply this function to grad/hess to ensure reproducible floating point
-    summation.
-
-    Algorithm 5: Reproducible Sequential Sum in 'Fast Reproducible
-    Floating-Point Summation' by Demmel and Nguyen.
-
-    Instead of using max(abs(x)) * n as an upper bound we use sum(abs(x))
-    """
-    assert x.dtype == cn.float32 or x.dtype == cn.float64
-    m = cn.sum(cn.abs(x))
-    n = x.size
-    delta = cn.floor(m / (one - two * n * eps))
-    delta = cn.maximum(delta, one)
-    M = two ** cn.ceil(cn.log2(delta))
-    return (x + M) - M
-
-
 def get_store(input: Any) -> LogicalStore:
     """Extracts a Legate store from any object implementing the legate data
     interface.
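For reference, the deleted pre-rounding trick is easy to study outside cunumeric. A NumPy transcription of the function above (a sketch, assuming float64 input): after rounding, every element sits on a shared grid of spacing determined by M, so subsequent additions are exact and any summation order yields the same total.

import numpy as np

def preround(x: np.ndarray) -> np.ndarray:
    # Algorithm 5 (Reproducible Sequential Sum) from Demmel and Nguyen,
    # 'Fast Reproducible Floating-Point Summation', using sum(abs(x))
    # as the upper bound instead of max(abs(x)) * n.
    eps = np.finfo(np.float64).eps
    m = np.sum(np.abs(x))
    n = x.size
    delta = np.floor(m / (1.0 - 2.0 * n * eps))
    delta = np.maximum(delta, 1.0)
    M = 2.0 ** np.ceil(np.log2(delta))
    return (x + M) - M  # rounds each element to a multiple of ulp(M)

x = np.random.default_rng(0).standard_normal(1000)
r = preround(x)
assert r.sum() == r[::-1].sum()  # summation order no longer matters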
18 changes: 0 additions & 18 deletions src/cpp_utils/cpp_utils.cuh
@@ -90,24 +90,6 @@ void SumAllReduce(legate::TaskContext context, T* x, int count, cudaStream_t str
   }
 }
 
-#if __CUDA_ARCH__ < 600
-__device__ inline double atomicAdd(double* address, double val)
-{
-  unsigned long long int* address_as_ull = (unsigned long long int*)address;
-  unsigned long long int old = *address_as_ull, assumed;
-
-  do {
-    assumed = old;
-    old =
-      atomicCAS(address_as_ull, assumed, __double_as_longlong(val + __longlong_as_double(assumed)));
-
-    // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
-  } while (assumed != old);
-
-  return __longlong_as_double(old);
-}
-#endif
-
 #if THRUST_VERSION >= 101600
 #define DEFAULT_POLICY thrust::cuda::par_nosync
 #else
43 changes: 35 additions & 8 deletions src/cpp_utils/cpp_utils.h
@@ -43,10 +43,10 @@ void expect_axis_aligned(const ShapeAT& a, const ShapeBT& b, std::string file, i
 #define EXPECT_AXIS_ALIGNED(axis, shape_a, shape_b) \
   (expect_axis_aligned<axis>(shape_a, shape_b, __FILE__, __LINE__))
 
-template <typename ShapeT>
-void expect_is_broadcast(const ShapeT& shape, std::string file, int line)
+template <int DIM>
+void expect_is_broadcast(const legate::Rect<DIM>& shape, std::string file, int line)
 {
-  for (int i = 0; i < sizeof(shape.lo.x) / sizeof(shape.lo[0]); i++) {
+  for (int i = 0; i < DIM; i++) {
     std::stringstream ss;
     ss << "Expected a broadcast store. Got shape: " << shape << ".";
     expect(shape.lo[i] == 0, ss.str(), file, line);
@@ -131,19 +131,16 @@ void SumAllReduce(legate::TaskContext context, T* x, int count)
   std::vector<T> data(items_per_rank * num_ranks);
   std::copy(x, x + count, data.begin());
   std::vector<T> recvbuf(items_per_rank * num_ranks);
-  auto result =
-    legate::comm::coll::collAlltoall(data.data(), recvbuf.data(), items_per_rank, type, comm_ptr);
-  EXPECT(result == legate::comm::coll::CollSuccess, "CPU communicator failed.");
+  legate::comm::coll::collAlltoall(data.data(), recvbuf.data(), items_per_rank, type, comm_ptr);
 
   // Sum partials
   std::vector<T> partials(items_per_rank, 0.0);
   for (size_t j = 0; j < items_per_rank; j++) {
     for (size_t i = 0; i < num_ranks; i++) { partials[j] += recvbuf[i * items_per_rank + j]; }
   }
 
-  result = legate::comm::coll::collAllgather(
+  legate::comm::coll::collAllgather(
     partials.data(), recvbuf.data(), items_per_rank, type, comm_ptr);
-  EXPECT(result == legate::comm::coll::CollSuccess, "CPU communicator failed.");
   std::copy(recvbuf.begin(), recvbuf.begin() + count, x);
 }
 
@@ -201,12 +198,42 @@ class UnravelIter {
     return copy;
   }
 
+  template <typename DistanceT>
+  __host__ __device__ UnravelIter operator-(DistanceT n)
+  {
+    UnravelIter copy = *this;
+    copy -= n;
+    return copy;
+  }
+
   __host__ __device__ bool operator!=(UnravelIter const& other) const
   {
     return current_ != other.current_;
   }
 
+  __host__ __device__ bool operator==(UnravelIter const& other) const
+  {
+    return current_ == other.current_;
+  }
+
+  __host__ __device__ UnravelIter operator-(UnravelIter other) const
+  {
+    other.current_ = current_ - other.current_;
+    return other;
+  }
+
+  template <typename DistanceT>
+  __host__ __device__ value_type operator[](DistanceT n) const
+  {
+    return UnravelIndex(current_ + n, shape_);
+  }
+
+  template <typename DistanceT>
+  __host__ __device__ value_type operator()(DistanceT n) const
+  {
+    return this->operator[](n);
+  }
+
   __host__ __device__ value_type operator*() const { return UnravelIndex(current_, shape_); }
 };
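The CPU SumAllReduce above composes an allreduce from an alltoall, a local reduction of the received chunks, and an allgather. A NumPy simulation of that communication pattern (a sketch of the data movement only, not the Legate collective API):

import numpy as np

def sum_allreduce(buffers):
    # Each 'rank' holds one buffer; all buffers have equal length.
    num_ranks = len(buffers)
    count = buffers[0].size
    items_per_rank = (count + num_ranks - 1) // num_ranks
    padded_len = items_per_rank * num_ranks
    padded = [np.pad(b, (0, padded_len - count)) for b in buffers]
    # Alltoall: rank j receives the j-th chunk of every rank's buffer.
    chunks = [
        np.stack([p[j * items_per_rank : (j + 1) * items_per_rank] for p in padded])
        for j in range(num_ranks)
    ]
    # Local reduction: each rank sums the chunks it received.
    partials = [c.sum(axis=0) for c in chunks]
    # Allgather: every rank reassembles the full reduced buffer.
    return np.concatenate(partials)[:count]

bufs = [np.arange(5.0) + r for r in range(3)]
print(sum_allreduce(bufs))  # [ 3.  6.  9. 12. 15.]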
4 changes: 2 additions & 2 deletions src/legate_library.h
@@ -22,8 +22,8 @@ struct Registry {
 
 template <typename T, int ID>
 struct Task : public legate::LegateTask<T> {
-  using Registrar = Registry;
-  static constexpr int TASK_ID = ID;
+  using Registrar = Registry;
+  static constexpr auto TASK_ID = legate::LocalTaskID{ID};
 };
 
 }  // namespace legateboost
13 changes: 6 additions & 7 deletions src/models/nn/build_nn.cu
@@ -24,9 +24,8 @@ void SyncCPU(legate::TaskContext context)
   auto comm_ptr = comm.get<legate::comm::coll::CollComm>();
   EXPECT(comm_ptr != nullptr, "CPU communicator is null.");
   float tmp;
-  auto result = legate::comm::coll::collAllgather(
+  legate::comm::coll::collAllgather(
     &tmp, gather_result.data(), 1, legate::comm::coll::CollDataType::CollFloat, comm_ptr);
-  EXPECT(result == legate::comm::coll::CollSuccess, "CPU communicator failed.");
 }
 
 // Store handles to legate and cublas
@@ -271,20 +270,20 @@ T eval_cost(NNContext* context,
   cub::DeviceReduce::Sum(nullptr,
                          temp_storage_bytes,
                          cost_array.data,
-                         result.ptr({0}),
+                         result.ptr(0),
                          cost_array.size(),
                          context->stream);
   auto temp_storage = legate::create_buffer<int8_t>({temp_storage_bytes});
-  cub::DeviceReduce::Sum(temp_storage.ptr({0}),
+  cub::DeviceReduce::Sum(temp_storage.ptr(0),
                          temp_storage_bytes,
                          cost_array.data,
-                         result.ptr({0}),
+                         result.ptr(0),
                          cost_array.size(),
                          context->stream);
-  SumAllReduce(context->legate_context, result.ptr({0}), 1, context->stream);
+  SumAllReduce(context->legate_context, result.ptr(0), 1, context->stream);
 
   T cost;
-  cudaMemcpyAsync(&cost, result.ptr({0}), sizeof(T), cudaMemcpyDeviceToHost, context->stream);
+  cudaMemcpyAsync(&cost, result.ptr(0), sizeof(T), cudaMemcpyDeviceToHost, context->stream);
   CHECK_CUDA(cudaStreamSynchronize(context->stream));
   if (alpha > 0.0) {
     T L2 = 0.0;