feat: log_softmax decomposition #3137
Conversation
aten._log_softmax dynamo converter
@HolyWu Is the decomposition slower than the converter?
@peri044 Yes, it's slower. Benchmarked with:

```python
import timeit

import numpy as np
import torch
import torch_tensorrt


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.m = torch.nn.LogSoftmax(dim=1)

    def forward(self, x):
        return self.m(x)


@torch.inference_mode()
def benchmark(model, inputs):
    # Warm up
    for _ in range(3):
        model(*inputs)
    torch.cuda.synchronize()

    timings = []
    for _ in range(100):
        start_time = timeit.default_timer()
        model(*inputs)
        torch.cuda.synchronize()
        end_time = timeit.default_timer()
        timings.append(end_time - start_time)
    return np.array(timings)


torch.manual_seed(12345)

device = torch.device("cuda", 0)
model = MyModule().eval().to(device).half()
inputs = (torch.randn((1, 1000000), dtype=torch.half, device=device),)

trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    enabled_precisions={torch.half},
    debug=True,
    min_block_size=1,
    device=device,
    cache_built_engines=False,
    reuse_cached_engines=False,
)

torch_timings = benchmark(model, inputs)
trt_timings = benchmark(trt_model, inputs)

print("")
print("Torch:")
print(f"\tMin={torch_timings.min()}, Mean={torch_timings.mean()}, Max={torch_timings.max()}")
print("")
print("TRT:")
print(f"\tMin={trt_timings.min()}, Mean={trt_timings.mean()}, Max={trt_timings.max()}")

torch.testing.assert_close(model(*inputs), trt_model(*inputs), rtol=5e-3, atol=5e-3)
```

Before patch:

```
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
%x : [num_users=1] = placeholder[target=x]
%log_softmax : [num_users=1] = call_function[target=torch.ops.aten.log_softmax.int](args = (%x, 1), kwargs = {})
return (log_softmax,)
DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
%x : [num_users=1] = placeholder[target=x]
%_to_copy : [num_users=2] = call_function[target=torch.ops.aten._to_copy.default](args = (%x,), kwargs = {dtype: torch.float32})
%amax : [num_users=1] = call_function[target=torch.ops.aten.amax.default](args = (%_to_copy, [1], True), kwargs = {})
%sub : [num_users=2] = call_function[target=torch.ops.aten.sub.Tensor](args = (%_to_copy, %amax), kwargs = {})
%exp : [num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub,), kwargs = {})
%sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%exp, [1], True), kwargs = {})
%log : [num_users=1] = call_function[target=torch.ops.aten.log.default](args = (%sum_1,), kwargs = {})
%sub_1 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sub, %log), kwargs = {})
%_to_copy_1 : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%sub_1,), kwargs = {dtype: torch.float16})
return (_to_copy_1,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
%x : [num_users=1] = placeholder[target=x]
%_to_copy : [num_users=2] = call_function[target=torch.ops.aten._to_copy.default](args = (%x,), kwargs = {dtype: torch.float32})
%amax : [num_users=1] = call_function[target=torch.ops.aten.amax.default](args = (%_to_copy, [1], True), kwargs = {})
%sub : [num_users=2] = call_function[target=torch.ops.aten.sub.Tensor](args = (%_to_copy, %amax), kwargs = {})
%exp : [num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub,), kwargs = {})
%sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%exp, [1], True), kwargs = {})
%log : [num_users=1] = call_function[target=torch.ops.aten.log.default](args = (%sum_1,), kwargs = {})
%sub_1 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sub, %log), kwargs = {})
%_to_copy_1 : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%sub_1,), kwargs = {dtype: torch.float16})
return (_to_copy_1,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes:
graph():
%x : [num_users=1] = placeholder[target=x]
%_to_copy : [num_users=2] = call_function[target=torch.ops.aten._to_copy.default](args = (%x,), kwargs = {dtype: torch.float32})
%amax : [num_users=1] = call_function[target=torch.ops.aten.amax.default](args = (%_to_copy, [1], True), kwargs = {})
%sub : [num_users=2] = call_function[target=torch.ops.aten.sub.Tensor](args = (%_to_copy, %amax), kwargs = {})
%exp : [num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub,), kwargs = {})
%sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%exp, [1], True), kwargs = {})
%log : [num_users=1] = call_function[target=torch.ops.aten.log.default](args = (%sum_1,), kwargs = {})
%sub_1 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sub, %log), kwargs = {})
%_to_copy_1 : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%sub_1,), kwargs = {dtype: torch.float16})
return (_to_copy_1,)
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
%x : [num_users=1] = placeholder[target=x]
%_to_copy : [num_users=2] = call_function[target=torch.ops.aten._to_copy.default](args = (%x,), kwargs = {dtype: torch.float32})
%amax : [num_users=1] = call_function[target=torch.ops.aten.amax.default](args = (%_to_copy, [1], True), kwargs = {})
%sub : [num_users=2] = call_function[target=torch.ops.aten.sub.Tensor](args = (%_to_copy, %amax), kwargs = {})
%exp : [num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub,), kwargs = {})
%sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%exp, [1], True), kwargs = {})
%log : [num_users=1] = call_function[target=torch.ops.aten.log.default](args = (%sum_1,), kwargs = {})
%sub_1 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sub, %log), kwargs = {})
%_to_copy_1 : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%sub_1,), kwargs = {dtype: torch.float16})
return (_to_copy_1,)
INFO:torch_tensorrt.dynamo._compiler:Compilation Settings: CompilationSettings(enabled_precisions={<dtype.f16: 6>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, make_refitable=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='C:\\Users\\HolyWu\\AppData\\Local\\Temp\\torch_tensorrt_engine_cache\\timing_cache.bin', lazy_engine_init=False, cache_built_engines=False, reuse_cached_engines=False)
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten._to_copy.default + Operator Count: 2
- torch.ops.aten.amax.default + Operator Count: 1
- torch.ops.aten.sub.Tensor + Operator Count: 2
- torch.ops.aten.exp.default + Operator Count: 1
- torch.ops.aten.sum.dim_IntList + Operator Count: 1
- torch.ops.aten.log.default + Operator Count: 1
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported
DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 8 operators out of 8 in subgraph.
INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten._to_copy.default + Operator Count: 2
- torch.ops.aten.amax.default + Operator Count: 1
- torch.ops.aten.sub.Tensor + Operator Count: 2
- torch.ops.aten.exp.default + Operator Count: 1
- torch.ops.aten.sum.dim_IntList + Operator Count: 1
- torch.ops.aten.log.default + Operator Count: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported
DEBUG:torch_tensorrt.dynamo._compiler:Converting submodule: _run_on_acc_0
Input shapes: [(1, 1000000)]
graph():
%x : [num_users=1] = placeholder[target=x]
%_to_copy : [num_users=2] = call_function[target=torch.ops.aten._to_copy.default](args = (%x,), kwargs = {dtype: torch.float32})
%amax : [num_users=1] = call_function[target=torch.ops.aten.amax.default](args = (%_to_copy, [1], True), kwargs = {})
%sub : [num_users=2] = call_function[target=torch.ops.aten.sub.Tensor](args = (%_to_copy, %amax), kwargs = {})
%exp : [num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub,), kwargs = {})
%sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%exp, [1], True), kwargs = {})
%log : [num_users=1] = call_function[target=torch.ops.aten.log.default](args = (%sum_1,), kwargs = {})
%sub_1 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sub, %log), kwargs = {})
%_to_copy_1 : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%sub_1,), kwargs = {dtype: torch.float16})
return _to_copy_1
WARNING:torch_tensorrt.dynamo.utils:Could not detect the device on which the model exists. Assuming the model is on CPU
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init CUDA: CPU +1, GPU +0, now: CPU 12070, GPU 1019 (MiB)
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init builder kernel library: CPU +3074, GPU +384, now: CPU 15445, GPU 1403 (MiB)
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node x (kind: x, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: x [shape=[1, 1000000], dtype=DataType.HALF]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node x [x] (Inputs: () | Outputs: (x: (1, 1000000)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/_to_copy (kind: aten._to_copy.default, args: ('x <Node>',))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/_to_copy [aten._to_copy.default] (Inputs: (x: (1, 1000000)@torch.float16) | Outputs: (_to_copy: (1, 1000000)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/amax (kind: aten.amax.default, args: ('_to_copy <Node>', ['1 <int>'], 'True <bool>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/amax [aten.amax.default] (Inputs: (_to_copy: (1, 1000000)@torch.float32, [1], True) | Outputs: (amax: (1, 1)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/sub (kind: aten.sub.Tensor, args: ('_to_copy <Node>', 'amax <Node>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/sub [aten.sub.Tensor] (Inputs: (_to_copy: (1, 1000000)@torch.float32, amax: (1, 1)@torch.float32) | Outputs: (sub: (1, 1000000)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/exp (kind: aten.exp.default, args: ('sub <Node>',))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/exp [aten.exp.default] (Inputs: (sub: (1, 1000000)@torch.float32) | Outputs: (exp: (1, 1000000)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/sum_1 (kind: aten.sum.dim_IntList, args: ('exp <Node>', ['1 <int>'], 'True <bool>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/sum_1 [aten.sum.dim_IntList] (Inputs: (exp: (1, 1000000)@torch.float32, [1], True) | Outputs: (sum_1: (1, 1)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/log (kind: aten.log.default, args: ('sum_1 <Node>',))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/log [aten.log.default] (Inputs: (sum_1: (1, 1)@torch.float32) | Outputs: (log: (1, 1)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/sub_1 (kind: aten.sub.Tensor, args: ('sub <Node>', 'log <Node>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/sub_1 [aten.sub.Tensor] (Inputs: (sub: (1, 1000000)@torch.float32, log: (1, 1)@torch.float32) | Outputs: (sub_1: (1, 1000000)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/_to_copy_1 (kind: aten._to_copy.default, args: ('sub_1 <Node>',))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/_to_copy_1 [aten._to_copy.default] (Inputs: (sub_1: (1, 1000000)@torch.float32) | Outputs: (_to_copy_1: (1, 1000000)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node output (kind: output, args: ('_to_copy_1 <Node>',))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Marking output output0 [shape=(1, 1000000), dtype=DataType.HALF]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node output [output] (Inputs: (_to_copy_1: (1, 1000000)@torch.float16) | Outputs: (output: ))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.004882
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Not found cached TRT engines. Start building engine.
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Host Persistent Memory: 704
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Device Persistent Memory: 0
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Scratch Memory: 0
INFO:torch_tensorrt [TensorRT Conversion Context]:[BlockAssignment] Started assigning block shifts. This will take 5 steps to complete.
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 0 MiB, GPU 0 MiB
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageStats] Peak memory usage during Engine building and serialization: CPU: 4154 MiB
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:00.366514
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 21356 bytes of Memory
INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 11384 bytes of code generator cache.
INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 12978937 bytes of compilation cache.
INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 18127 timing cache entries
ERROR: [Torch-TensorRT] - Platform constructor: windows_x86_64
DEBUG:torch_tensorrt.dynamo._DryRunTracker:
++++++++++++++++++++++++++++++++++++++++++++++++++ Dry-Run Results for Graph ++++++++++++++++++++++++++++++++++++++++++++++++++
The graph consists of 8 Total Operators, of which 8 operators are supported, 100.0% coverage
Compiled with: CompilationSettings(enabled_precisions={<dtype.f16: 6>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, make_refitable=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='C:\\Users\\HolyWu\\AppData\\Local\\Temp\\torch_tensorrt_engine_cache\\timing_cache.bin', lazy_engine_init=False, cache_built_engines=False, reuse_cached_engines=False)
Graph Structure:
Inputs: List[Tensor: (1, 1000000)@float16]
...
TRT Engine #1 - Submodule name: _run_on_acc_0
Engine Inputs: List[Tensor: (1, 1000000)@float16]
Number of Operators in Engine: 8
Engine Outputs: List[Tensor: (1, 1000000)@float16]
...
Outputs: List[Tensor: (1, 1000000)@float16]
------------------------- Aggregate Stats -------------------------
Average Number of Operators per TRT Engine: 8.0
Most Operators in a TRT Engine: 8
********** Recommendations **********
- For minimal graph segmentation, select min_block_size=8 which would generate 1 TRT engine(s)
- The current level of graph segmentation is equivalent to selecting min_block_size=8 which generates 1 TRT engine(s)
WARNING: [Torch-TensorRT] - Using default stream in enqueueV3() may lead to performance issues due to additional calls to cudaStreamSynchronize() by TensorRT to ensure correct synchronization. Please use non-default stream instead.
Torch:
Min=0.00018600000112201087, Mean=0.00020198499998514308, Max=0.000484800000776886
TRT:
Min=0.005296899998938898, Mean=0.005410501000078512, Max=0.00578259999929287
```

After patch:

```
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
%x : [num_users=1] = placeholder[target=x]
%log_softmax : [num_users=1] = call_function[target=torch.ops.aten.log_softmax.int](args = (%x, 1), kwargs = {})
return (log_softmax,)
DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
%x : [num_users=1] = placeholder[target=x]
%_log_softmax : [num_users=1] = call_function[target=torch.ops.aten._log_softmax.default](args = (%x, 1, False), kwargs = {})
return (_log_softmax,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
%x : [num_users=1] = placeholder[target=x]
%_log_softmax : [num_users=1] = call_function[target=torch.ops.aten._log_softmax.default](args = (%x, 1, False), kwargs = {})
return (_log_softmax,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes:
graph():
%x : [num_users=1] = placeholder[target=x]
%_log_softmax : [num_users=1] = call_function[target=torch.ops.aten._log_softmax.default](args = (%x, 1, False), kwargs = {})
return (_log_softmax,)
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
%x : [num_users=1] = placeholder[target=x]
%_log_softmax : [num_users=1] = call_function[target=torch.ops.aten._log_softmax.default](args = (%x, 1, False), kwargs = {})
return (_log_softmax,)
INFO:torch_tensorrt.dynamo._compiler:Compilation Settings: CompilationSettings(enabled_precisions={<dtype.f16: 6>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, make_refitable=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='C:\\Users\\HolyWu\\AppData\\Local\\Temp\\torch_tensorrt_engine_cache\\timing_cache.bin', lazy_engine_init=False, cache_built_engines=False, reuse_cached_engines=False)
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten._log_softmax.default + Operator Count: 1
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported
DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 1 operators out of 1 in subgraph.
INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten._log_softmax.default + Operator Count: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported
DEBUG:torch_tensorrt.dynamo._compiler:Converting submodule: _run_on_acc_0
Input shapes: [(1, 1000000)]
graph():
%x : [num_users=1] = placeholder[target=x]
%_log_softmax : [num_users=1] = call_function[target=torch.ops.aten._log_softmax.default](args = (%x, 1, False), kwargs = {})
return _log_softmax
WARNING:torch_tensorrt.dynamo.utils:Could not detect the device on which the model exists. Assuming the model is on CPU
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init CUDA: CPU +3, GPU +0, now: CPU 11764, GPU 1019 (MiB)
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init builder kernel library: CPU +3065, GPU +384, now: CPU 15152, GPU 1403 (MiB)
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node x (kind: x, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: x [shape=[1, 1000000], dtype=DataType.HALF]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node x [x] (Inputs: () | Outputs: (x: (1, 1000000)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/_log_softmax (kind: aten._log_softmax.default, args: ('x <Node>', '1 <int>', 'False <bool>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/_log_softmax [aten._log_softmax.default] (Inputs: (x: (1, 1000000)@torch.float16, 1, False) | Outputs: (_log_softmax: (1, 1000000)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node output (kind: output, args: ('_log_softmax <Node>',))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Marking output output0 [shape=(1, 1000000), dtype=DataType.HALF]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node output [output] (Inputs: (_log_softmax: (1, 1000000)@torch.float16) | Outputs: (output: ))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.001952
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Not found cached TRT engines. Start building engine.
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Host Persistent Memory: 1840
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 0 MiB, GPU 0 MiB
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageStats] Peak memory usage during Engine building and serialization: CPU: 4150 MiB
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:00.168043
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 59020 bytes of Memory
INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 11384 bytes of code generator cache.
INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 12955309 bytes of compilation cache.
INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 18097 timing cache entries
ERROR: [Torch-TensorRT] - Platform constructor: windows_x86_64
DEBUG:torch_tensorrt.dynamo._DryRunTracker:
++++++++++++++++++++++++++++++++++++++++++++++++++ Dry-Run Results for Graph ++++++++++++++++++++++++++++++++++++++++++++++++++
The graph consists of 1 Total Operators, of which 1 operators are supported, 100.0% coverage
Compiled with: CompilationSettings(enabled_precisions={<dtype.f16: 6>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, make_refitable=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='C:\\Users\\HolyWu\\AppData\\Local\\Temp\\torch_tensorrt_engine_cache\\timing_cache.bin', lazy_engine_init=False, cache_built_engines=False, reuse_cached_engines=False)
Graph Structure:
Inputs: List[Tensor: (1, 1000000)@float16]
...
TRT Engine #1 - Submodule name: _run_on_acc_0
Engine Inputs: List[Tensor: (1, 1000000)@float16]
Number of Operators in Engine: 1
Engine Outputs: List[Tensor: (1, 1000000)@float16]
...
Outputs: List[Tensor: (1, 1000000)@float16]
------------------------- Aggregate Stats -------------------------
Average Number of Operators per TRT Engine: 1.0
Most Operators in a TRT Engine: 1
********** Recommendations **********
- For minimal graph segmentation, select min_block_size=1 which would generate 1 TRT engine(s)
- The current level of graph segmentation is equivalent to selecting min_block_size=1 which generates 1 TRT engine(s)
WARNING: [Torch-TensorRT] - Using default stream in enqueueV3() may lead to performance issues due to additional calls to cudaStreamSynchronize() by TensorRT to ensure correct synchronization. Please use non-default stream instead.
Torch:
Min=0.00017979999938688707, Mean=0.0001965939999627153, Max=0.00043619999996735714
TRT:
Min=0.00292390000049636, Mean=0.003028965000012249, Max=0.003767100000004575
```
What happens if you just write a new decomposition that just inserts softmax + log?
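For concreteness, a minimal sketch of what such a decomposition could look like. The function name and the custom registry are placeholders, not code from this PR; Torch-TensorRT maintains its own decomposition table, so the actual hook-up may differ.

```python
import torch
from torch._decomp import register_decomposition

# Register into a custom table to avoid clashing with PyTorch's built-in
# aten._log_softmax decomposition in the global registry.
CUSTOM_DECOMPOSITIONS = {}


@register_decomposition(torch.ops.aten._log_softmax, registry=CUSTOM_DECOMPOSITIONS)
def log_softmax_decomposition(
    x: torch.Tensor, dim: int, half_to_float: bool
) -> torch.Tensor:
    # Lower aten._log_softmax to softmax + log, both of which have native
    # TRT converters, instead of the long amax/sub/exp/sum/log chain.
    y = x.float() if half_to_float else x  # half_to_float requests an fp32 result
    return torch.log(torch.ops.aten._softmax.default(y, dim, False))
```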
Actually, it seems like the PyTorch decomp pass is mostly just decomposing softmax: …
Is this the case for softmax as well? Maybe all we need to do is disable the softmax decomp?
It also works and achieves similar performance to the converter.

```
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
%x : [num_users=1] = placeholder[target=x]
%log_softmax : [num_users=1] = call_function[target=torch.ops.aten.log_softmax.int](args = (%x, 1), kwargs = {})
return (log_softmax,)
DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
%x : [num_users=1] = placeholder[target=x]
%_softmax : [num_users=1] = call_function[target=torch.ops.aten._softmax.default](args = (%x, 1, False), kwargs = {})
%log : [num_users=1] = call_function[target=torch.ops.aten.log.default](args = (%_softmax,), kwargs = {})
return (log,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
%x : [num_users=1] = placeholder[target=x]
%_softmax : [num_users=1] = call_function[target=torch.ops.aten._softmax.default](args = (%x, 1, False), kwargs = {})
%log : [num_users=1] = call_function[target=torch.ops.aten.log.default](args = (%_softmax,), kwargs = {})
return (log,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes:
graph():
%x : [num_users=1] = placeholder[target=x]
%_softmax : [num_users=1] = call_function[target=torch.ops.aten._softmax.default](args = (%x, 1, False), kwargs = {})
%log : [num_users=1] = call_function[target=torch.ops.aten.log.default](args = (%_softmax,), kwargs = {})
return (log,)
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
%x : [num_users=1] = placeholder[target=x]
%_softmax : [num_users=1] = call_function[target=torch.ops.aten._softmax.default](args = (%x, 1, False), kwargs = {})
%log : [num_users=1] = call_function[target=torch.ops.aten.log.default](args = (%_softmax,), kwargs = {})
return (log,)
INFO:torch_tensorrt.dynamo._compiler:Compilation Settings: CompilationSettings(enabled_precisions={<dtype.f16: 6>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, make_refitable=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='C:\\Users\\HolyWu\\AppData\\Local\\Temp\\torch_tensorrt_engine_cache\\timing_cache.bin', lazy_engine_init=False, cache_built_engines=False, reuse_cached_engines=False)
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten._softmax.default + Operator Count: 1
- torch.ops.aten.log.default + Operator Count: 1
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported
DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 2 operators out of 2 in subgraph.
INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten._softmax.default + Operator Count: 1
- torch.ops.aten.log.default + Operator Count: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported
DEBUG:torch_tensorrt.dynamo._compiler:Converting submodule: _run_on_acc_0
Input shapes: [(1, 1000000)]
graph():
%x : [num_users=1] = placeholder[target=x]
%_softmax : [num_users=1] = call_function[target=torch.ops.aten._softmax.default](args = (%x, 1, False), kwargs = {})
%log : [num_users=1] = call_function[target=torch.ops.aten.log.default](args = (%_softmax,), kwargs = {})
return log
WARNING:torch_tensorrt.dynamo.utils:Could not detect the device on which the model exists. Assuming the model is on CPU
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init CUDA: CPU +1, GPU +0, now: CPU 11848, GPU 1019 (MiB)
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init builder kernel library: CPU +3072, GPU +384, now: CPU 15220, GPU 1403 (MiB)
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node x (kind: x, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: x [shape=[1, 1000000], dtype=DataType.HALF]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node x [x] (Inputs: () | Outputs: (x: (1, 1000000)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/_softmax (kind: aten._softmax.default, args: ('x <Node>', '1 <int>', 'False <bool>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/_softmax [aten._softmax.default] (Inputs: (x: (1, 1000000)@torch.float16, 1, False) | Outputs: (_softmax: (1, 1000000)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/log (kind: aten.log.default, args: ('_softmax <Node>',))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/log [aten.log.default] (Inputs: (_softmax: (1, 1000000)@torch.float16) | Outputs: (log: (1, 1000000)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node output (kind: output, args: ('log <Node>',))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Marking output output0 [shape=(1, 1000000), dtype=DataType.HALF]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node output [output] (Inputs: (log: (1, 1000000)@torch.float16) | Outputs: (output: ))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.002930
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Not found cached TRT engines. Start building engine.
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Host Persistent Memory: 1840
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 0 MiB, GPU 0 MiB
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageStats] Peak memory usage during Engine building and serialization: CPU: 4154 MiB
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:00.164269
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 58996 bytes of Memory
INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 11384 bytes of code generator cache.
INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 12978937 bytes of compilation cache.
INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 18125 timing cache entries
ERROR: [Torch-TensorRT] - Platform constructor: windows_x86_64
DEBUG:torch_tensorrt.dynamo._DryRunTracker:
++++++++++++++++++++++++++++++++++++++++++++++++++ Dry-Run Results for Graph ++++++++++++++++++++++++++++++++++++++++++++++++++
The graph consists of 2 Total Operators, of which 2 operators are supported, 100.0% coverage
Compiled with: CompilationSettings(enabled_precisions={<dtype.f16: 6>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, make_refitable=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='C:\\Users\\HolyWu\\AppData\\Local\\Temp\\torch_tensorrt_engine_cache\\timing_cache.bin', lazy_engine_init=False, cache_built_engines=False, reuse_cached_engines=False)
Graph Structure:
Inputs: List[Tensor: (1, 1000000)@float16]
...
TRT Engine #1 - Submodule name: _run_on_acc_0
Engine Inputs: List[Tensor: (1, 1000000)@float16]
Number of Operators in Engine: 2
Engine Outputs: List[Tensor: (1, 1000000)@float16]
...
Outputs: List[Tensor: (1, 1000000)@float16]
------------------------- Aggregate Stats -------------------------
Average Number of Operators per TRT Engine: 2.0
Most Operators in a TRT Engine: 2
********** Recommendations **********
- For minimal graph segmentation, select min_block_size=2 which would generate 1 TRT engine(s)
- The current level of graph segmentation is equivalent to selecting min_block_size=2 which generates 1 TRT engine(s)
WARNING: [Torch-TensorRT] - Using default stream in enqueueV3() may lead to performance issues due to additional calls to cudaStreamSynchronize() by TensorRT to ensure correct synchronization. Please use non-default stream instead.
Torch:
Min=0.00018660000023373868, Mean=0.0001979710000068735, Max=0.00045529999988502823
TRT:
Min=0.0029398000006040093, Mean=0.0030551729999569945, Max=0.0036590999998225016
```
PyTorch's …
I think we would rather have a decomposition than a converter, so I think we can just disable the PyTorch log_softmax decomposition and insert the one you wrote.
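A hedged sketch of what that could look like, building on the log_softmax_decomposition sketch above; how Torch-TensorRT actually assembles and consumes its lowering table is an assumption here, not taken from this PR.

```python
import torch
from torch._decomp import core_aten_decompositions

# Start from PyTorch's core ATen decomposition table, drop the built-in
# _log_softmax entry, and substitute the softmax + log version instead.
decomp_table = core_aten_decompositions()
decomp_table.pop(torch.ops.aten._log_softmax.default, None)
decomp_table[torch.ops.aten._log_softmax.default] = log_softmax_decomposition
```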
aten._log_softmax dynamo converter
@peri044 @narendasan Please review again. |
Added handling of … The culprit is:

TensorRT/py/torch_tensorrt/dynamo/utils.py, line 193 in e4e4d31

The test cases passed if I changed the above line to …
This line was added in #3002, but …
I think it should be fine to alter the implementation of …
I discussed this with @apbose and we are now changing the fallback device from CPU to the default device (cuda:0). So PR #3008 should address that.
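For illustration only, the kind of fallback change being discussed might look like the following; the helper name is hypothetical, and the real code lives in utils.py.

```python
import torch


# Hypothetical helper: when the device cannot be inferred from the model,
# fall back to the current CUDA device rather than CPU.
def default_device() -> torch.device:
    if torch.cuda.is_available():
        return torch.device("cuda", torch.cuda.current_device())
    return torch.device("cpu")
```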
Overall this looks good to me at this point; I think this is a cleaner and more comprehensive approach! Just one open question for @peri044, but other than that I'm pretty happy with this.
This reverts commit 77993db.
LGTM