Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

k2 build without pytorch #1164

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
13 changes: 11 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,18 @@ if(NOT K2_WITH_CUDA)
set(K2_ENABLE_NVTX OFF CACHE BOOL "" FORCE)
endif()

if(K2_WITH_CUDA)
find_package(CUDA QUIET)
message(STATUS "CUDA_INCLUDE_DIRS=${CUDA_INCLUDE_DIRS}")
message(STATUS "CUDA_LIBRARIES=${CUDA_LIBRARIES}")
message(STATUS "CUDA_nvToolsExt_LIBRARY=${CUDA_nvToolsExt_LIBRARY}")
message(STATUS "CUDA_TOOLKIT_ROOT_DIR=${CUDA_TOOLKIT_ROOT_DIR}")
endif()

zh794390558 marked this conversation as resolved.
Show resolved Hide resolved
if(NOT K2_USE_PYTORCH)
message(FATAL_ERROR "\
Please set K2_USE_PYTORCH to ON.
message(WARNING "\
K2_USE_PYTORCH is OFF; only the k2 core library will be built.
If you want to use k2 with PyTorch, please turn it ON.
Support for other frameworks will be added later")
endif()

Expand Down
4 changes: 3 additions & 1 deletion k2/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
add_subdirectory(csrc)
add_subdirectory(python)
if(K2_USE_PYTORCH)
add_subdirectory(python)
endif()

if(K2_USE_PYTORCH)
# We use K2_TORCH_VERSION instead of TORCH_VERSION
Expand Down
5 changes: 3 additions & 2 deletions k2/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ add_library(k2_nvtx INTERFACE)
target_include_directories(k2_nvtx INTERFACE ${CMAKE_SOURCE_DIR})
if(K2_ENABLE_NVTX)
target_compile_definitions(k2_nvtx INTERFACE K2_ENABLE_NVTX=1)
target_include_directories(k2_nvtx INTERFACE ${CUDA_INCLUDE_DIRS})
target_link_libraries(k2_nvtx INTERFACE ${CUDA_nvToolsExt_LIBRARY})
if(WIN32)
target_include_directories(k2_nvtx INTERFACE
${CUDA_TOOLKIT_ROOT_DIR}/include/nvtx3
Expand Down Expand Up @@ -79,14 +81,13 @@ set(context_srcs
thread_pool.cu
timer.cu
top_sort.cu
torch_util.cu
utils.cu
nbest.cu
)


if(K2_USE_PYTORCH)
list(APPEND context_srcs pytorch_context.cu)
list(APPEND context_srcs pytorch_context.cu torch_util.cu)
else()
list(APPEND context_srcs default_context.cu)
endif()
Expand Down
50 changes: 48 additions & 2 deletions k2/csrc/default_context.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "k2/csrc/context.h"
#include "k2/csrc/log.h"
#include "k2/csrc/nvtx.h"
#include "k2/csrc/device_guard.h"

namespace k2 {

Expand All @@ -32,7 +33,7 @@ static constexpr std::size_t kAlignment = 64;
class CpuContext : public Context {
public:
CpuContext() = default;
ContextPtr GetCpuContext() override { return shared_from_this(); }
ContextPtr GetCpuContext() { return shared_from_this(); }
DeviceType GetDeviceType() const override { return kCpu; }

void *Allocate(std::size_t bytes, void **deleter_context) override {
Expand All @@ -52,6 +53,28 @@ class CpuContext : public Context {
void Deallocate(void *data, void * /*deleter_context*/) override {
free(data);
}

// Copies `num_bytes` bytes from `src` (host memory owned by this CPU
// context) to `dst`, which lives on `dst_context`.
//
// CPU -> CPU is a plain memcpy.  CPU -> CUDA stages the data through a
// pinned (page-locked) host buffer and delegates the device upload to the
// pinned context.  Any other destination device type is a fatal error.
void CopyDataTo(size_t num_bytes, const void *src, ContextPtr dst_context,
                void *dst) override {
  DeviceType device_type = dst_context->GetDeviceType();
  switch (device_type) {
    case kCpu:
      // Host -> host: direct copy.
      memcpy(dst, src, num_bytes);
      break;
    case kCuda: {
      // CPU -> CUDA
      // Make the destination device current for the duration of the copy.
      DeviceGuard guard(dst_context);
      // Stage through pinned memory so the host->device transfer can use
      // page-locked pages (required for true async DMA transfers).
      ContextPtr pinned_context = GetPinnedContext();
      auto region = NewRegion(pinned_context, num_bytes);
      memcpy(region->data, src, num_bytes);
      pinned_context->CopyDataTo(num_bytes, region->data, dst_context, dst);
      // NOTE(review): `region` is released when this function returns;
      // this assumes the pinned context's CopyDataTo completes (or safely
      // extends the buffer's lifetime) before returning — confirm.
      break;
    }
    default:
      K2_LOG(FATAL) << "Unsupported device type: " << device_type;
      break;
  }
}
};

class CudaContext : public Context {
Expand All @@ -66,7 +89,7 @@ class CudaContext : public Context {
auto ret = cudaStreamCreate(&stream_);
K2_CHECK_CUDA_ERROR(ret);
}
ContextPtr GetCpuContext() override { return k2::GetCpuContext(); }
ContextPtr GetCpuContext() { return k2::GetCpuContext(); }
DeviceType GetDeviceType() const override { return kCuda; }
int32_t GetDeviceId() const override { return gpu_id_; }

Expand Down Expand Up @@ -98,6 +121,29 @@ class CudaContext : public Context {
K2_CHECK_CUDA_ERROR(ret);
}

// Copies `num_bytes` bytes from `src` (device memory owned by this CUDA
// context) to `dst` on `dst_context`.
//
// CUDA -> CPU uses a synchronous cudaMemcpy, so `dst` is valid when this
// function returns.  CUDA -> CUDA enqueues an asynchronous copy on the
// destination context's stream.  Other destinations are a fatal error.
void CopyDataTo(size_t num_bytes, const void *src, ContextPtr dst_context,
                void *dst) override {
  DeviceType device_type = dst_context->GetDeviceType();
  switch (device_type) {
    case kCpu: {
      // Device -> host: blocking copy (cudaMemcpy synchronizes).
      cudaError_t ret =
          cudaMemcpy(dst, src, num_bytes, cudaMemcpyDeviceToHost);
      K2_CHECK_CUDA_ERROR(ret);
      break;
    }
    case kCuda: {
      // Device -> device: asynchronous, ordered on the *destination*
      // context's stream.
      // NOTE(review): there is no explicit synchronization with this
      // (source) context's stream here — presumably callers guarantee
      // prior work on `src` has completed or both contexts share a
      // stream; verify against the call sites.
      cudaError_t ret =
          cudaMemcpyAsync(dst, src, num_bytes, cudaMemcpyDeviceToDevice,
                          dst_context->GetCudaStream());
      K2_CHECK_CUDA_ERROR(ret);
      break;
    }
    default:
      K2_LOG(FATAL) << "Unsupported device type: " << device_type;
      break;
  }
}

~CudaContext() {
auto ret = cudaStreamDestroy(stream_);
K2_CHECK_CUDA_ERROR(ret);
Expand Down