chore: Functionalize inputs setup
keehyuna committed Nov 25, 2024
1 parent 4a5f0d1 commit f480353
Showing 4 changed files with 135 additions and 115 deletions.
139 changes: 72 additions & 67 deletions core/runtime/execute_engine.cpp
@@ -91,7 +91,77 @@ bool _validate_shapes(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngi

return false;
}
void setup_input_tensors(
std::vector<at::Tensor> inputs,
c10::intrusive_ptr<TRTEngine> compiled_engine,
bool need_cudagraphs_record) {
// this is a buffer to store shape tensor input addresses throughout the runtime scope
std::list<std::vector<int64_t>> inputShapeTensorValues;
std::list<at::Tensor> formatted_inputs(compiled_engine->num_io.first);

for (size_t i = 0; i < inputs.size(); i++) {
std::string name = compiled_engine->in_binding_names[i];

TORCHTRT_CHECK(
inputs[i].is_cuda(), "Expected input tensors to have device cuda, found device " << inputs[i].device());

auto expected_type =
util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
TORCHTRT_CHECK(
inputs[i].dtype() == expected_type,
"Expected input tensors to have type " << expected_type << ", found type " << inputs[i].dtype());

auto dims = core::util::toDims(inputs[i].sizes());
auto shape = core::util::toVec(dims);
LOG_DEBUG("Input Name: " << name << " Shape: " << dims);

if (compiled_engine->cuda_engine->isShapeInferenceIO(name.c_str())) {
// Shape tensor inputs are cast to int64 explicitly.
// Refer to
// https://github.com/NVIDIA/TensorRT/blob/d2f4ef789a9a6ffdf37b55c3f81b486225f6b380/samples/common/sampleInference.cpp#L435
auto input_cpu = inputs[i].clone().contiguous().cpu().to(torch::kInt64);
std::vector<int64_t> inputs_cpu_vec(
input_cpu.data_ptr<int64_t>(), input_cpu.data_ptr<int64_t>() + input_cpu.numel());
inputShapeTensorValues.emplace_back(inputs_cpu_vec);
TORCHTRT_CHECK(
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputShapeTensorValues.back().data()),
"Error while setting the tensor address for shape inputs");

if (CUDAGRAPHS_MODE) {
// @peri044 I don't know if this makes sense since they are supposed to be GPU buffers
compiled_engine->input_buffers[i] = input_cpu;
}
TORCHTRT_CHECK(
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputShapeTensorValues.back().data()),
"Error while setting the tensor address for shape inputs");

} else {
at::Tensor contig_input = inputs[i].view(shape).contiguous();
formatted_inputs.emplace_back(std::move(contig_input));

if (need_cudagraphs_record) {
// Create a new persistent input buffer
compiled_engine->input_buffers[i] = std::move(formatted_inputs.back().clone());
}

TORCHTRT_CHECK(
compiled_engine->exec_ctx->setInputShape(name.c_str(), dims), "Error while setting the input shape");

if (CUDAGRAPHS_MODE) {
// If using CUDAGraphs copy formatted input to the corresponding persistent input buffer
compiled_engine->input_buffers[i].copy_(formatted_inputs.back(), true);
TORCHTRT_CHECK(
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), compiled_engine->input_buffers[i].data_ptr()),
"Error while setting the input tensor address for inputs");
} else {
// Otherwise use the formatted buffer directly
TORCHTRT_CHECK(
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), formatted_inputs.back().data_ptr()),
"Error while setting the input tensor address for inputs");
}
}
}
}
std::vector<at::Tensor> create_output_tensors(c10::intrusive_ptr<TRTEngine> compiled_engine) {
std::vector<at::Tensor> outputs(compiled_engine->num_io.second);
for (auto output_indices : compiled_engine->out_binding_map) {
@@ -142,11 +212,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
compiled_engine->cudagraph.reset();
}

// this is a buffer to store shape tensor input addresses throughout the runtime scope
std::list<std::vector<int64_t>> inputShapeTensorValues;

// Initialize inputs and outputs to be available throughout the succeeding scopes
std::list<at::Tensor> formatted_inputs(compiled_engine->num_io.first);
std::vector<at::Tensor> outputs(compiled_engine->num_io.second);

if (MULTI_DEVICE_SAFE_MODE) {
@@ -204,68 +270,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->input_profile_path);
}

for (size_t i = 0; i < inputs.size(); i++) {
std::string name = compiled_engine->in_binding_names[i];

TORCHTRT_CHECK(
inputs[i].is_cuda(), "Expected input tensors to have device cuda, found device " << inputs[i].device());

auto expected_type =
util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
TORCHTRT_CHECK(
inputs[i].dtype() == expected_type,
"Expected input tensors to have type " << expected_type << ", found type " << inputs[i].dtype());

auto dims = core::util::toDims(inputs[i].sizes());
auto shape = core::util::toVec(dims);
LOG_DEBUG("Input Name: " << name << " Shape: " << dims);

if (compiled_engine->cuda_engine->isShapeInferenceIO(name.c_str())) {
// Shape tensor inputs are casted to int64 explicitly.
// Refer to
// https://github.com/NVIDIA/TensorRT/blob/d2f4ef789a9a6ffdf37b55c3f81b486225f6b380/samples/common/sampleInference.cpp#L435
auto input_cpu = inputs[i].clone().contiguous().cpu().to(torch::kInt64);
std::vector<int64_t> inputs_cpu_vec(
input_cpu.data_ptr<int64_t>(), input_cpu.data_ptr<int64_t>() + input_cpu.numel());
inputShapeTensorValues.emplace_back(inputs_cpu_vec);
TORCHTRT_CHECK(
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputShapeTensorValues.back().data()),
"Error while setting the tensor address for shape inputs");

if (CUDAGRAPHS_MODE) {
// @peri044 I dont know if this makes sense since they are supposed to be GPU buffers
compiled_engine->input_buffers[i] = input_cpu;
}
TORCHTRT_CHECK(
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputShapeTensorValues.back().data()),
"Error while setting the tensor address for shape inputs");

} else {
at::Tensor contig_input = inputs[i].view(shape).contiguous();
formatted_inputs.emplace_back(std::move(contig_input));

if (need_cudagraphs_record) {
// Create a new persistent input buffer
compiled_engine->input_buffers[i] = std::move(formatted_inputs.back().clone());
}

TORCHTRT_CHECK(
compiled_engine->exec_ctx->setInputShape(name.c_str(), dims), "Error while setting the input shape");

if (CUDAGRAPHS_MODE) {
// If using CUDAGraphs copy formatted input to the corresponding persistent input buffer
compiled_engine->input_buffers[i].copy_(formatted_inputs.back(), true);
TORCHTRT_CHECK(
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), compiled_engine->input_buffers[i].data_ptr()),
"Error while setting the input tensor address for inputs");
} else {
// Otherwise use the formatted buffer directly
TORCHTRT_CHECK(
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), formatted_inputs.back().data_ptr()),
"Error while setting the input tensor address for inputs");
}
}
}
setup_input_tensors(inputs, compiled_engine, need_cudagraphs_record);

// Check if input shapes can be inferred.
int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()};
@@ -284,7 +289,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
output_profiler_guard =
std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->output_profile_path);
}
if ((false == compiled_engine->use_pre_allocated_outputs) || shape_changed) {
if (!compiled_engine->use_pre_allocated_outputs || shape_changed) {
outputs = create_output_tensors(compiled_engine);
} else {
outputs = compiled_engine->pre_allocated_outputs;
104 changes: 56 additions & 48 deletions py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -109,7 +109,7 @@ def __init__(
self.target_platform = Platform.current_platform()
self.prev_cudagraphs_enabled = False
self.pre_allocated_outputs: List[torch.Tensor] = []
self.use_pre_allocated_outputs = True
self.use_pre_allocated_outputs = False

if self.serialized_engine is not None and not self.settings.lazy_engine_init:
self.setup_engine()
@@ -235,6 +235,57 @@ def __del__(self) -> None:
if self.cudagraph:
self.cudagraph.reset()

def setup_input_tensors(
self,
contiguous_inputs: List[torch.Tensor],
cudagraphs_enabled: bool,
need_cudagraphs_record: bool,
) -> None:
for i, input_name in enumerate(self.input_names):
if not contiguous_inputs[i].is_cuda:
logger.warning(
f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. "
"This tensor is being moved by the runtime but for performance considerations, "
"ensure your inputs are all on GPU and open an issue here "
"(https://github.com/pytorch/TensorRT/issues) if this warning persists."
)
contiguous_inputs = (
contiguous_inputs[:i]
+ [contiguous_inputs[i].cuda()]
+ contiguous_inputs[i + 1 :]
)

assert (
contiguous_inputs[i].dtype == self.input_dtypes[i]
), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."

if need_cudagraphs_record:
# If cudagraphs is enabled, this memory is reserved for future cudagraph runs
# Clone is required to avoid re-using user-provided GPU memory
self._input_buffers[i] = contiguous_inputs[i].clone()

# For shape tensors, we use CPU pointers and for data tensors, we use GPU pointers
# as per TensorRT requirements
if self.engine.is_shape_inference_io(input_name):
# Shape tensor inputs are cast to int64 explicitly
# Currently Torch CPU pointers are not working; numpy pointers are used instead
# to refer to underlying memory
inputs_cpu = contiguous_inputs[i].cpu().to(torch.int64).numpy().copy()
self.context.set_tensor_address(input_name, inputs_cpu.ctypes.data)
else:
self.context.set_input_shape(
input_name, tuple(contiguous_inputs[i].shape)
)
if cudagraphs_enabled:
self._input_buffers[i].copy_(contiguous_inputs[i])
self.context.set_tensor_address(
input_name, self._input_buffers[i].data_ptr()
)
else:
self.context.set_tensor_address(
input_name, contiguous_inputs[i].data_ptr()
)

def create_output_tensors(self) -> List[torch.Tensor]:
# create output tensors
outputs: List[torch.Tensor] = []
@@ -272,6 +323,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
need_cudagraphs_record = True
else:
need_cudagraphs_record = cudagraphs_enabled and shape_changed

self.prev_cudagraphs_enabled = cudagraphs_enabled

if need_cudagraphs_record:
@@ -327,54 +379,10 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
self.input_names
), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}."

for i, input_name in enumerate(self.input_names):
if not contiguous_inputs[i].is_cuda:
logger.warning(
f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. "
"This tensor is being moved by the runtime but for performance considerations, "
"ensure your inputs are all on GPU and open an issue here "
"(https://github.com/pytorch/TensorRT/issues) if this warning persists."
)
contiguous_inputs = (
contiguous_inputs[:i]
+ [contiguous_inputs[i].cuda()]
+ contiguous_inputs[i + 1 :]
)

assert (
contiguous_inputs[i].dtype == self.input_dtypes[i]
), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."
self.setup_input_tensors(
contiguous_inputs, cudagraphs_enabled, need_cudagraphs_record
)

if need_cudagraphs_record:
# If cudagraphs is enabled, this memory is reserved for future cudagraph runs
# Clone is required to avoid re-using user-provided GPU memory
self._input_buffers[i] = contiguous_inputs[i].clone()

# For shape tensors, we use CPU pointers and for data tensors, we use GPU pointers
# as per TensorRT requirements
if self.engine.is_shape_inference_io(input_name):
# Shape tensor inputs are casted to int64 explicitly
# Currently Torch CPU pointers are not working; numpy pointers are used instead
# to refer to underlying memory
inputs_cpu = (
contiguous_inputs[i].cpu().to(torch.int64).numpy().copy()
)
self.context.set_tensor_address(
input_name, inputs_cpu.ctypes.data
)
else:
self.context.set_input_shape(
input_name, tuple(contiguous_inputs[i].shape)
)
if cudagraphs_enabled:
self._input_buffers[i].copy_(contiguous_inputs[i])
self.context.set_tensor_address(
input_name, self._input_buffers[i].data_ptr()
)
else:
self.context.set_tensor_address(
input_name, contiguous_inputs[i].data_ptr()
)
if shape_changed:
# Check if input shapes can be inferred.
uninferred_input_names = self.context.infer_shapes()
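
Editorial note on the persistent input buffers used above: CUDA Graph capture bakes device pointers into the recorded graph, so replays always read from the same addresses. That is why, when recording, both runtimes clone the formatted inputs into `input_buffers` / `_input_buffers` and later `copy_` fresh data into those buffers instead of rebinding new tensors. Below is a minimal, self-contained sketch of that pattern in plain PyTorch — not Torch-TensorRT code, and it assumes a CUDA device is available.

import torch

# Persistent buffer: its device address is what the captured graph will replay against.
static_input = torch.zeros(8, device="cuda")

# Warm up on a side stream before capture, as the PyTorch CUDA Graphs docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_input * 2.0
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_output = static_input * 2.0  # work recorded into the graph

# New inputs must be copied into the recorded buffer; rebinding a new tensor would
# leave the graph reading stale memory. This mirrors input_buffers[i].copy_(...) above.
static_input.copy_(torch.arange(8, dtype=torch.float32, device="cuda"))
g.replay()
print(static_output)  # tensor([0., 2., 4., 6., 8., 10., 12., 14.], device='cuda:0')
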
3 changes: 3 additions & 0 deletions tests/py/dynamo/runtime/test_002_cudagraphs_cpp.py
@@ -17,6 +17,9 @@
"Torch-TensorRT runtime is not available",
)
class TestCudagraphsCPP(TestCase):
def tearDown(self):
# Reset to default cuda graph mode after each test
torch_tensorrt.runtime.set_cudagraphs_mode(False)

def test_cudagraphs_on(self):
torch_tensorrt.runtime.set_cudagraphs_mode(True)
4 changes: 4 additions & 0 deletions tests/py/dynamo/runtime/test_002_cudagraphs_py.py
@@ -13,6 +13,10 @@


class TestCudagraphsPython(TestCase):
def tearDown(self):
# Reset to default cuda graph mode after each test
torch_tensorrt.runtime.set_cudagraphs_mode(False)

def test_cudagraphs_on(self):
torch_tensorrt.runtime.set_cudagraphs_mode(True)
self.assertTrue(torch_tensorrt.runtime.get_cudagraphs_mode())
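
The `tearDown` hooks added to both test suites matter because the cudagraphs mode is process-global: a test that turns it on and then fails (or simply forgets to turn it off) would silently change the behavior of every test that runs afterwards. A hypothetical test in the same style is sketched below, using only the `set_cudagraphs_mode` / `get_cudagraphs_mode` calls that appear in the diffs above; plain `unittest.TestCase` is assumed here for self-containment.

import unittest
import torch_tensorrt

class TestCudagraphsToggle(unittest.TestCase):
    def tearDown(self):
        # Reset the process-global flag so later tests start from the default (off) state.
        torch_tensorrt.runtime.set_cudagraphs_mode(False)

    def test_cudagraphs_toggle(self):
        torch_tensorrt.runtime.set_cudagraphs_mode(True)
        self.assertTrue(torch_tensorrt.runtime.get_cudagraphs_mode())
        torch_tensorrt.runtime.set_cudagraphs_mode(False)
        self.assertFalse(torch_tensorrt.runtime.get_cudagraphs_mode())

if __name__ == "__main__":
    unittest.main()
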
