Skip to content

Commit

Permalink
feat(csrc): add callDiopiKeepContext utility (#42)
Browse files Browse the repository at this point in the history
  • Loading branch information
lljbash authored Jan 30, 2024
1 parent 141c74d commit 9a80a19
Showing 1 changed file with 32 additions and 18 deletions.
50 changes: 32 additions & 18 deletions csrc/diopi_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,21 +45,19 @@ inline void checkTensorOnDevice(const c10::optional<at::Tensor>& tensor) {
// at::Tensor -> diopiTensorHandle_t
// const at::Tensor -> diopiConstTensorHandle_t
// c10::optional<at::Tensor> -> diopiConstTensorHandle_t
template <
class T, class U = std::decay_t<T>,
std::enable_if_t<std::is_same<U, at::Tensor>::value ||
std::is_same<U, c10::optional<at::Tensor>>::value,
int> = 0>
decltype(auto) castToDiopiType(T&& tensor) {
template <class T, class U = std::decay_t<T>,
std::enable_if_t<std::is_same_v<U, at::Tensor> ||
std::is_same_v<U, c10::optional<at::Tensor>>,
int> = 0>
[[nodiscard]] decltype(auto) castToDiopiType(T&& tensor) {
checkTensorOnDevice(tensor);
return diopi_helper::toDiopiTensorHandle(std::forward<T>(tensor));
}

// at::OptionalArrayRef -> diopiSize_t
template <
class T, class U = std::decay_t<T>,
std::enable_if_t<std::is_same<U, at::OptionalIntArrayRef>::value, int> = 0>
decltype(auto) castToDiopiType(T&& shape) {
template <class T, class U = std::decay_t<T>,
std::enable_if_t<std::is_same_v<U, at::OptionalIntArrayRef>, int> = 0>
[[nodiscard]] decltype(auto) castToDiopiType(T&& shape) {
return diopi_helper::toDiopiSize(std::forward<T>(shape));
}

Expand All @@ -69,17 +67,17 @@ template <class T, class U = std::decay_t<T>,
std::enable_if_t<std::is_same<U, at::Generator>() ||
std::is_same<U, c10::optional<at::Generator>>(),
int> = 0>
decltype(auto) castToDiopiType(T&& gen) {
[[nodiscard]] decltype(auto) castToDiopiType(T&& gen) {
return diopi_helper::toDiopiGeneratorHandle(std::forward<T>(gen));
}

// c10::optional<ArithmeticType> -> const ArithmeticType*
template <
class T, class U = std::decay_t<T>,
std::enable_if_t<type_traits::IsOptionalArithmetic<U>::value, int> = 0>
auto castToDiopiType(T&& opt) -> const typename U::value_type* {
[[nodiscard]] auto castToDiopiType(T&& opt) -> const typename U::value_type* {
if (opt) {
return &(*opt);
return &(*std::forward<T>(opt));
}
return nullptr;
}
Expand All @@ -88,15 +86,19 @@ auto castToDiopiType(T&& opt) -> const typename U::value_type* {
// Pointer -> Pointer
template <
class T, class U = std::decay_t<T>,
std::enable_if_t<std::is_arithmetic<U>::value || std::is_pointer<U>::value,
int> = 0>
decltype(auto) castToDiopiType(T&& arg) {
std::enable_if_t<std::is_arithmetic_v<U> || std::is_pointer_v<U>, int> = 0>
[[nodiscard]] decltype(auto) castToDiopiType(T&& arg) {
return std::forward<T>(arg);
}

// NOTE: This function will keep the context in the upper stack frame.
// You usually don't need to explicit use the return value, i.e.
// `[[maybe_unused]] auto context = callDiopiKeepContext(...);`
// is what you should do in most cases.
template <class DiopiFunc, class... Args>
void callDiopi(DiopiFunc&& diopi_func, Args&&... args) {
static_assert(std::is_function<std::remove_reference_t<DiopiFunc>>::value,
[[nodiscard]] diopiContext callDiopiKeepContext(const DiopiFunc& diopi_func,
Args&&... args) {
static_assert(std::is_function_v<std::remove_reference_t<DiopiFunc>>,
"DiopiFunc must be a function");
diopiContext ctx(dipu::getCurrentDIPUStream().rawstream());
diopiError_t err_code =
Expand All @@ -105,6 +107,18 @@ void callDiopi(DiopiFunc&& diopi_func, Args&&... args) {
throw std::runtime_error("DIOPI error, code: " + std::to_string(err_code) +
", message: " + diopiGetLastErrorString());
}
return ctx;
}

// WARNING: This function will destruct the context after the function call. If
// you need to keep the context (e.g. casting a `diopiTensorHandle_t*`
// allocated by `diopiRequireTensor` during the `diopi_func` call to a
// `at::Tensor` and then using it somewhere outside), please use
// `callDiopiKeepContext` instead.
template <class DiopiFunc, class... Args>
void callDiopi(const DiopiFunc& diopi_func, Args&&... args) {
[[maybe_unused]] auto context =
callDiopiKeepContext(diopi_func, std::forward<Args>(args)...);
}

} // namespace dipu_ext
Expand Down

0 comments on commit 9a80a19

Please sign in to comment.