Skip to content

Commit

Permalink
chore: tri-state CUDA graphs mode
Browse files Browse the repository at this point in the history
  • Loading branch information
keehyuna committed Nov 27, 2024
1 parent 711930f commit 32056f3
Show file tree
Hide file tree
Showing 12 changed files with 158 additions and 247 deletions.
4 changes: 0 additions & 4 deletions core/runtime/TRTEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,6 @@ TRTEngine::TRTEngine(
LOG_DEBUG(*this);
}

void TRTEngine::set_whole_cudagraphs(bool enable) {
whole_cudagraphs = enable;
}

TRTEngine::~TRTEngine() {
trt_engine_profiler.reset();
exec_ctx.reset();
Expand Down
4 changes: 1 addition & 3 deletions core/runtime/TRTEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ struct TRTEngine : torch::CustomClassHolder {
bool set_device_memory_budget(int64_t budget);
int64_t get_streamable_device_memory_budget();
int64_t get_automatic_device_memory_budget();
void set_whole_cudagraphs(bool enable);
std::vector<at::Tensor> infer_outputs(std::vector<std::vector<int64_t>> input_shapes);
friend std::ostream& operator<<(std::ostream& os, const TRTEngine& engine);
static const char BINDING_DELIM = '%';
Expand All @@ -104,13 +103,12 @@ struct TRTEngine : torch::CustomClassHolder {
std::vector<at::Tensor> output_buffers = {};
std::string shape_key;
bool prev_cudagraphs_enabled = false;
bool whole_cudagraphs = false;
// TODO: Implement a call method
// c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);

void set_profiling_paths();
#ifndef NDEBUG
bool profile_execution = true;
bool profile_execution = false;
#else
bool profile_execution = false;
#endif
Expand Down
2 changes: 1 addition & 1 deletion core/runtime/execute_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
LOG_INFO("" << log_info);
compiled_engine->cudagraph.enable_debug_mode();
}
bool cudagraphs_enabled = (!compiled_engine->whole_cudagraphs && CUDAGRAPHS_MODE);
bool cudagraphs_enabled = (CUDAGRAPHS_MODE == SUBGRAPH_CUDAGRAPHS);

// Whether cudagraphs needs to record the graph on this pass
// Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change
Expand Down
7 changes: 4 additions & 3 deletions core/runtime/register_jit_hooks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
.def("dump_engine_layer_info_to_file", &TRTEngine::dump_engine_layer_info_to_file)
.def("dump_engine_layer_info", &TRTEngine::dump_engine_layer_info)
.def("get_engine_layer_info", &TRTEngine::get_engine_layer_info)
.def("set_whole_cudagraphs", &TRTEngine::set_whole_cudagraphs)
.def("infer_outputs", &TRTEngine::infer_outputs)
.def_property(
"device_memory_budget",
Expand All @@ -112,8 +111,10 @@ TORCH_LIBRARY(tensorrt, m) {
m.def("set_multi_device_safe_mode", [](bool multi_device_safe_mode) -> void {
MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode;
});
m.def("get_cudagraphs_mode", []() -> bool { return CUDAGRAPHS_MODE; });
m.def("set_cudagraphs_mode", [](bool cudagraphs_mode) -> void { CUDAGRAPHS_MODE = cudagraphs_mode; });
m.def("get_cudagraphs_mode", []() -> int64_t { return CUDAGRAPHS_MODE; });
m.def("set_cudagraphs_mode", [](int64_t cudagraphs_mode) -> void {
CUDAGRAPHS_MODE = CudaGraphsMode(cudagraphs_mode);
});
m.def("set_logging_level", [](int64_t level) -> void {
util::logging::get_logger().set_reportable_log_level(util::logging::LogLevel(level));
});
Expand Down
6 changes: 3 additions & 3 deletions core/runtime/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace core {
namespace runtime {

bool MULTI_DEVICE_SAFE_MODE = false;
bool CUDAGRAPHS_MODE = false;
CudaGraphsMode CUDAGRAPHS_MODE = STANDARD;

c10::optional<RTDevice> get_most_compatible_device(
const RTDevice& target_device,
Expand Down Expand Up @@ -130,11 +130,11 @@ void set_multi_device_safe_mode(bool multi_device_safe_mode) {
MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode;
}

bool get_cudagraphs_mode() {
// Report the CUDA graphs mode currently active for the runtime.
CudaGraphsMode get_cudagraphs_mode() {
  const auto current_mode = CUDAGRAPHS_MODE;
  return current_mode;
}

void set_cudagraphs_mode(bool cudagraphs_mode) {
// Select the CUDA graphs mode for the runtime by updating the process-wide
// CUDAGRAPHS_MODE flag that execute_engine consults on each call.
void set_cudagraphs_mode(CudaGraphsMode cudagraphs_mode) {
  CUDAGRAPHS_MODE = cudagraphs_mode;
}

Expand Down
13 changes: 10 additions & 3 deletions core/runtime/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,14 @@ namespace runtime {
using EngineID = int64_t;
const std::string ABI_VERSION = "6";
extern bool MULTI_DEVICE_SAFE_MODE;
extern bool CUDAGRAPHS_MODE;

// Tri-state CUDA graphs mode for the runtime:
//   STANDARD               - default; the per-engine cudagraph path is not used.
//   SUBGRAPH_CUDAGRAPHS    - each TRT engine records/replays its own CUDA graph
//                            (execute_engine enables cudagraphs only for this value).
//   WHOLE_GRAPH_CUDAGRAPHS - a parent wrapper module is presumed to own CUDA
//                            graph capture, so individual engines must not capture.
// NOTE: the TorchScript bindings round-trip this enum through int64_t
// (get/set_cudagraphs_mode), so the numeric values are part of the runtime
// ABI and must stay stable; they are spelled out explicitly for that reason.
// A named enum replaces the C-style typedef-of-anonymous-enum; the type name
// and unqualified enumerators are unchanged, so all existing uses still compile.
enum CudaGraphsMode {
  STANDARD = 0,
  SUBGRAPH_CUDAGRAPHS = 1,
  WHOLE_GRAPH_CUDAGRAPHS = 2,
};

extern CudaGraphsMode CUDAGRAPHS_MODE;

typedef enum {
ABI_TARGET_IDX = 0,
Expand Down Expand Up @@ -51,9 +58,9 @@ bool get_multi_device_safe_mode();

void set_multi_device_safe_mode(bool multi_device_safe_mode);

bool get_cudagraphs_mode();
CudaGraphsMode get_cudagraphs_mode();

void set_cudagraphs_mode(bool cudagraphs_mode);
void set_cudagraphs_mode(CudaGraphsMode cudagraphs_mode);

class DeviceList {
using DeviceMap = std::unordered_map<int, RTDevice>;
Expand Down
16 changes: 2 additions & 14 deletions py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,6 @@ def __init__(
self.engine = None
self.weight_name_map = weight_name_map
self.target_platform = Platform.current_platform()
# Check if CUDA graph capture is enabled in the parent node
self.whole_cudagraphs = False
# Previous cuda graphs state
self.prev_cudagraphs_enabled = False

Expand Down Expand Up @@ -151,14 +149,6 @@ def set_default_device_memory_budget(self) -> int:
logger.debug(f"Weight streaming budget set to {budget_bytes}B")
return self._set_device_memory_budget(budget_bytes)

def set_whole_cudagraphs(self, enable: bool) -> None:
    """Mark whether a parent wrapper owns CUDA graph capture.

    When the global CUDA graphs mode is enabled, the enclosing wrapper module
    performs all CUDA graph recording and replay, so this child module must
    opt out of its own CUDA graph functionality to avoid conflicting captures.
    """
    self.whole_cudagraphs = enable

def setup_engine(self) -> None:
assert (
self.target_platform == Platform.current_platform()
Expand Down Expand Up @@ -257,10 +247,8 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
):
self._check_initialized()

cudagraphs_enabled = (
torch_tensorrt.runtime.get_cudagraphs_mode()
and not self.whole_cudagraphs
)
cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()

# Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change
need_cudagraphs_record = cudagraphs_enabled and (
(not self.prev_cudagraphs_enabled)
Expand Down
8 changes: 0 additions & 8 deletions py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,14 +195,6 @@ def set_device_memory_budget(self, budget_bytes: int) -> int:

return budget_bytes

def set_whole_cudagraphs(self, enable: bool) -> None:
    """Forward the whole-graph CUDA graphs flag to the underlying TRT engine.

    When the global CUDA graphs mode is enabled, the parent wrapper module
    handles all CUDA graph recording and replay; the native engine is told
    here to disable its own CUDA graph functionality so the two do not
    conflict.
    """
    self.engine.set_whole_cudagraphs(enable)

def setup_engine(self) -> None:
"""
Setup engine for a module which has deferred engine setup.
Expand Down
Loading

0 comments on commit 32056f3

Please sign in to comment.