Skip to content

Commit

Permalink
refactor: 尝试直接调用 cuda runtime api
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Jan 8, 2024
1 parent 535134b commit 0bc8305
Show file tree
Hide file tree
Showing 9 changed files with 38 additions and 59 deletions.
8 changes: 4 additions & 4 deletions src/02hardware/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@ cmake_minimum_required(VERSION 3.12 FATAL_ERROR)
project(hardware VERSION 0.0.0 LANGUAGES CXX)
message(STATUS "Project " ${PROJECT_NAME} " version " ${PROJECT_VERSION})

if(USE_CUDA)
file(GLOB_RECURSE HARDWARE_CUDA_SRC src/*.cu)
endif()

file(GLOB_RECURSE HARDWARE_SRC src/*.cc src/*.cpp)
add_library(hardware STATIC ${HARDWARE_SRC} ${HARDWARE_CUDA_SRC})
target_link_libraries(hardware PUBLIC common)
target_include_directories(hardware PUBLIC include)

if(USE_CUDA)
target_include_directories(hardware PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
endif()

file(GLOB_RECURSE HARDWARE_TEST test/*.cpp)
if(HARDWARE_TEST)
add_executable(hardware_test ${HARDWARE_TEST})
Expand Down
2 changes: 1 addition & 1 deletion src/02hardware/include/hardware/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ namespace refactor::hardware {

virtual ~Device() = default;
virtual Type type() const noexcept = 0;
virtual void setContext() const noexcept;
virtual void setContext() const;

Arc<Blob> malloc(size_t);
Arc<Blob> absorb(Arc<Blob> &&);
Expand Down
2 changes: 1 addition & 1 deletion src/02hardware/include/hardware/devices/nvidia.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace refactor::hardware {
class Nvidia final : public Device {
public:
explicit Nvidia(int32_t card);
void setContext() const noexcept final;
void setContext() const final;
Type type() const noexcept final {
return Type::Nvidia;
}
Expand Down
2 changes: 1 addition & 1 deletion src/02hardware/src/device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ namespace refactor::hardware {
Device::Device(decltype(_card) card, decltype(_mem) mem)
: _card(card), _mem(std::move(mem)) {}

void Device::setContext() const noexcept {}
void Device::setContext() const {}
auto Device::malloc(size_t size) -> Arc<Blob> {
return Arc<Blob>(new Blob(this, size));
}
Expand Down
25 changes: 18 additions & 7 deletions src/02hardware/src/devices/nvidia/device.cc
Original file line number Diff line number Diff line change
@@ -1,17 +1,28 @@
#include "hardware/devices/nvidia.h"
#include "hardware/mem_pool.h"

#ifdef USE_CUDA
#include "functions.cuh"
#include "memory.cuh"
#include "memory.hh"
#include <cuda_runtime.h>

// Evaluates the CUDA runtime call STATUS exactly once; when the result is not
// cudaSuccess, invokes RUNTIME_ERROR with a message containing the stringified
// call, the text from cudaGetErrorString and the numeric error code.
//
// The body is wrapped in do { ... } while (0) so the macro expands to a single
// statement: `if (c) CUDA_ASSERT(x); else ...` keeps its `else` bound to the
// caller's `if` instead of the macro's internal one. Call sites must end the
// invocation with a semicolon.
// The local is named cudaAssertStatus_ so it cannot shadow an identifier
// (e.g. a caller's own `status`) referenced inside the STATUS expression.
#define CUDA_ASSERT(STATUS)                                                              \
    do {                                                                                 \
        if (auto cudaAssertStatus_ = (STATUS); cudaAssertStatus_ != cudaSuccess) {       \
            RUNTIME_ERROR(fmt::format("cuda failed on \"" #STATUS "\" with \"{}\" ({})", \
                                      cudaGetErrorString(cudaAssertStatus_),             \
                                      static_cast<int>(cudaAssertStatus_)));             \
        }                                                                                \
    } while (0)
#endif

namespace refactor::hardware {

static Arc<Memory> cudaMemory(int32_t card) {
#ifdef USE_CUDA
ASSERT(0 <= card && card < getDeviceCount(), "Invalid card id: {}", card);
setDevice(card);
auto [free, total] = getMemInfo();
int deviceCount;
CUDA_ASSERT(cudaGetDeviceCount(&deviceCount));
ASSERT(0 <= card && card < deviceCount, "Invalid card id: {}", card);
CUDA_ASSERT(cudaSetDevice(card));

size_t free, total;
CUDA_ASSERT(cudaMemGetInfo(&free, &total));
auto size = std::min(free, std::max(5ul << 30, total * 4 / 5));
fmt::println("initializing Nvidia GPU {}, memory {} / {}, alloc {}",
card, free, total, size);
Expand All @@ -26,9 +37,9 @@ namespace refactor::hardware {

Nvidia::Nvidia(int32_t card) : Device(card, cudaMemory(card)) {}

void Nvidia::setContext() const noexcept {
void Nvidia::setContext() const {
#ifdef USE_CUDA
setDevice(_card);
CUDA_ASSERT(cudaSetDevice(_card));
#endif
}

Expand Down
19 changes: 0 additions & 19 deletions src/02hardware/src/devices/nvidia/functions.cu

This file was deleted.

24 changes: 0 additions & 24 deletions src/02hardware/src/devices/nvidia/functions.cuh

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
#include "functions.cuh"
#include "memory.cuh"
#ifdef USE_CUDA

#include "memory.hh"
#include "common.h"
#include <cuda_runtime.h>

// Evaluates the CUDA runtime call STATUS exactly once; when the result is not
// cudaSuccess, invokes RUNTIME_ERROR with a message containing the stringified
// call, the text from cudaGetErrorString and the numeric error code.
//
// The body is wrapped in do { ... } while (0) so the macro expands to a single
// statement: `if (c) CUDA_ASSERT(x); else ...` keeps its `else` bound to the
// caller's `if` instead of the macro's internal one. Call sites must end the
// invocation with a semicolon.
// The local is named cudaAssertStatus_ so it cannot shadow an identifier
// (e.g. a caller's own `status`) referenced inside the STATUS expression.
#define CUDA_ASSERT(STATUS)                                                              \
    do {                                                                                 \
        if (auto cudaAssertStatus_ = (STATUS); cudaAssertStatus_ != cudaSuccess) {       \
            RUNTIME_ERROR(fmt::format("cuda failed on \"" #STATUS "\" with \"{}\" ({})", \
                                      cudaGetErrorString(cudaAssertStatus_),             \
                                      static_cast<int>(cudaAssertStatus_)));             \
        }                                                                                \
    } while (0)

namespace refactor::hardware {
using M = NvidiaMemory;
Expand Down Expand Up @@ -29,3 +38,5 @@ namespace refactor::hardware {
}

}// namespace refactor::hardware

#endif
File renamed without changes.

0 comments on commit 0bc8305

Please sign in to comment.