diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 884c51e8ea..2b32e5e015 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -928,6 +928,30 @@ def aten_ops_slice(
     )
 
 
+@dynamo_tensorrt_converter(torch.ops.aten.chunk.default)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+def aten_ops_chunk(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.slice.chunk(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+        args_bounds_check(args, 2, 0),
+    )
+
+
 def refit_validator(node: Node, settings: CompilationSettings = None) -> bool:
     # cumsum op is not refitable
     if settings and settings.make_refittable:
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
index 3274d78c2b..9b5c96b162 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
@@ -655,3 +655,60 @@ def nested(
     )
 
     return reshape_output
+
+
+def chunk(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    chunks: int,
+    dim: int,
+) -> List[TRTTensor]:
+    if chunks <= 0:
+        raise RuntimeError(
+            f"chunk expects `chunks` to be greater than 0, got: {chunks}"
+        )
+
+    shape = input.shape
+    dim = get_positive_dim(dim, len(shape))
+
+    if dim >= len(shape):
+        raise RuntimeError(
+            f"chunk expects `dim` to be less than the length of input shape, got: {dim}"
+        )
+
+    if has_dynamic_shape(shape):
+        # Chunking along a dynamic dimension is not supported
+        assert shape[dim] != -1, "Can't chunk on dynamic shape dimension!"
+
+    # PyTorch semantics: each chunk spans ceil(size_dim / chunks) elements;
+    # the last chunk may be smaller, and fewer than `chunks` may be produced
+    size_dim = shape[dim]
+    chunk_size = math.ceil(size_dim / chunks)
+    result = []
+    start = 0
+    end = min(start + chunk_size, size_dim)
+    cnt = 0
+
+    # Emit one TensorRT slice layer per chunk along `dim`
+    while start < end:
+        result.append(
+            slice_op(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_slice_{cnt}",
+                input,
+                dim,
+                start,
+                end,
+                1,
+            )
+        )
+        start = end
+        end = min(start + chunk_size, size_dim)
+        cnt += 1
+
+    return result
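
For reviewers, a minimal standalone sketch of the chunking arithmetic the new `impl.slice.chunk` implements; `chunk_bounds` is a hypothetical helper written only to illustrate the loop above, not part of this patch:

```python
import math

def chunk_bounds(size_dim: int, chunks: int):
    # Mirrors the while-loop in impl.slice.chunk: each chunk spans
    # ceil(size_dim / chunks) elements along the target dim; the final
    # chunk may be smaller, and fewer than `chunks` may be produced.
    chunk_size = math.ceil(size_dim / chunks)
    start = 0
    end = min(start + chunk_size, size_dim)
    while start < end:
        yield start, end
        start = end
        end = min(start + chunk_size, size_dim)

# A length-10 dim split into 3 chunks -> [(0, 4), (4, 8), (8, 10)],
# matching the sizes (4, 4, 2) that torch.chunk(torch.arange(10), 3) returns.
print(list(chunk_bounds(10, 3)))
# A length-5 dim split into 4 chunks yields only 3: [(0, 2), (2, 4), (4, 5)].
print(list(chunk_bounds(5, 4)))
```

Each (start, end) pair is then lowered through `slice_op` with step 1, so one `aten.chunk` call becomes one TensorRT slice layer per output chunk.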