Expand doesn't handle dynamic shaped tensors. #1887

Closed
1 of 4 tasks
jxchenus opened this issue Jul 3, 2024 · 3 comments
Labels
not a bug (Some known limitation, but not a bug.), others

Comments

@jxchenus

jxchenus commented Jul 3, 2024

System Info

AWS p3.8xlarge (https://aws.amazon.com/ec2/instance-types/p3/)

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

test_expand_dyn.py.txt

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import numpy as np

# isort: off
import torch
import tensorrt as trt
# isort: on
from polygraphy.backend.trt import (CreateConfig, EngineFromNetwork, Profile,
                                    TrtRunner)

import tensorrt_llm
from tensorrt_llm import Tensor


class TestFunctional(unittest.TestCase):

    def setUp(self):
        tensorrt_llm.logger.set_level('error')

    def test_expand_dyn_input(self):
        # test data
        dtype = 'float32'
        input_shape = (-1, 10)

        # construct trt network
        builder = tensorrt_llm.Builder()
        net = builder.create_network()
        with tensorrt_llm.net_guard(net):
            network = tensorrt_llm.default_trtnet()
            input = Tensor(name='input',
                           shape=input_shape,
                           dtype=tensorrt_llm.str_dtype_to_trt(dtype))
            shape = tensorrt_llm.functional.constant(np.array([1, 10]))
            output = tensorrt_llm.functional.expand(input, shape)
            output_shape = tensorrt_llm.functional.shape(output, 0)

        # trt run
        profiles = [Profile().add('shape', (1, 1), input_shape, (10, 10))]
        build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network),
                                         config=CreateConfig(profiles=profiles))

    def test_expand_dyn_inout(self):
        # test data
        dtype = 'float32'
        input_shape = (-1, 10)

        # construct trt network
        builder = tensorrt_llm.Builder()
        net = builder.create_network()
        with tensorrt_llm.net_guard(net):
            network = tensorrt_llm.default_trtnet()
            input = Tensor(name='input',
                           shape=input_shape,
                           dtype=tensorrt_llm.str_dtype_to_trt(dtype))
            shape = tensorrt_llm.functional.constant(np.array([-1, 10]))
            output = tensorrt_llm.functional.expand(input, shape)
            output_shape = tensorrt_llm.functional.shape(output, 0)

        # trt run
        profiles = [Profile().add('shape', (1, 1), input_shape, (10, 10))]
        build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network),
                                         config=CreateConfig(profiles=profiles))

Rename the above file to test_expand_dyn.py and run:

python3 -m unittest test_expand_dyn.py

Expected behavior

The test should succeed.

Actual behavior

The test fails with the following:

[TensorRT-LLM] TensorRT-LLM version: 0.10.0
[07/03/2024-22:05:51] [TRT] [E] ITensor::getDimensions: Error Code 4: Internal Error (Tensor (Unnamed Layer* 1) [Slice]_output has axis 0 with inherently negative length. Proven upper bound is -1. Network must have an instance where axis has non-negative length.)
[07/03/2024-22:05:51] [TRT] [E] ITensor::getDimensions: Error Code 4: Internal Error (Output shape can not be computed for node SHAPE_4.)
E.
======================================================================
ERROR: test_expand_dyn_inout (test_expand_dyn.TestFunctional)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/TensorRT-LLM/tests/functional/test_expand_dyn.py", line 72, in test_expand_dyn_inout
    output_shape = tensorrt_llm.functional.shape(output, 0)
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/functional.py", line 1779, in shape
    return gather(res, dim=0, indices=dim).view([])
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/functional.py", line 1831, in gather
    assert input.rank() == indices.rank()
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/functional.py", line 476, in rank
    return len(self.trt_tensor.shape)
ValueError: __len__() should return >= 0

----------------------------------------------------------------------
Ran 2 tests in 1.425s

FAILED (errors=1)

Additional notes

When compiling a llama-like model using trtllm-build, tensors with a batch dimension have that dimension set to -1. But operators like expand, and the slice layer underneath it, seem unable to handle an input or output tensor that has a dimension of size -1, so subsequent operations that use the output fail to build.
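To make the failure mode concrete, here is a rough pure-Python model of the behavior, not TensorRT or TensorRT-LLM code. It assumes expand is lowered to a slice that uses stride 0 on size-1 axes, which is only a guess based on the [Slice] layer named in the error log. The function name `expand_via_slice` is hypothetical. The point it illustrates is that -1 is a placeholder in a network definition, while the shape a slice receives must hold concrete, non-negative lengths.

```python
from itertools import product


def expand_via_slice(data, in_shape, out_shape):
    """Model expand as a slice that uses stride 0 on size-1 axes.

    `data` is a flat row-major list holding a tensor of shape `in_shape`.
    `out_shape` plays the role of the shape tensor: every entry must be a
    concrete, non-negative length. (Hypothetical helper, for illustration.)
    """
    if len(in_shape) != len(out_shape):
        raise ValueError("rank mismatch")
    if any(d < 0 for d in out_shape):
        # -1 means "unknown until run time" in a network definition only;
        # it is never a valid length for an actual slice output.
        raise ValueError(f"negative length in output shape {out_shape}")

    # Stride 0 re-reads the single element along an axis of size 1.
    strides = []
    for size_in, size_out in zip(in_shape, out_shape):
        if size_in == 1:
            strides.append(0)
        elif size_in == size_out:
            strides.append(1)
        else:
            raise ValueError(
                f"cannot expand axis of size {size_in} to {size_out}")

    # Row-major element strides of the input buffer.
    elem_strides = [1] * len(in_shape)
    for axis in range(len(in_shape) - 2, -1, -1):
        elem_strides[axis] = elem_strides[axis + 1] * in_shape[axis + 1]

    out = []
    for index in product(*(range(n) for n in out_shape)):
        offset = sum(i * s * e
                     for i, s, e in zip(index, strides, elem_strides))
        out.append(data[offset])
    return out


row = [1.0, 2.0, 3.0]  # a tensor of shape (1, 3)
print(expand_via_slice(row, (1, 3), (2, 3)))  # the row repeated twice

try:
    expand_via_slice(row, (1, 3), (-1, 3))  # mirrors the [-1, 10] constant
except ValueError as err:
    print("rejected:", err)
```

In this model the second call fails for the same reason the log gives: an output axis with a negative length can never be satisfied, whatever the input's batch size turns out to be.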

@jxchenus
Author

jxchenus commented Jul 8, 2024

I found a proper way to do it: the dynamic shape itself should be a model input.

@jxchenus closed this as completed on Jul 8, 2024
@QiJune added the "not a bug" and "others" labels and removed the "bug" and "Investigating" labels on Aug 5, 2024
@yjjinjie

@jxchenus Hello, I ran into the same problem. How do you use the dynamic shape as a model input?

pytorch/TensorRT#3140

When batch_size is passed as an input, I have to read it with batch_size = grouped_features["batch_size"].item(), which then runs into another bug at:
queries = query.unsqueeze(1).expand(batch_size, max_seq_length, -1)

@jxchenus
Author

I found a workaround by using slice fill mode instead. The shape of the output can be dynamically passed in:

NVIDIA/TensorRT#3979 (comment)
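The linked comment is not reproduced here. As a loose pure-Python illustration of the general idea only, the sketch below models a 1-D slice whose output size is an ordinary run-time value and whose out-of-range reads yield a fill value. It assumes semantics similar to a fill-style slice mode and is not TensorRT code; the name `slice_fill_1d` and its parameters are hypothetical.

```python
def slice_fill_1d(data, start, size, stride=1, fill=0.0):
    """Model a 1-D slice where reads outside `data` return `fill`.

    `size` stands in for a shape supplied at run time, so it can differ
    from call to call, but it must be non-negative when the slice runs.
    (Hypothetical helper, for illustration.)
    """
    if size < 0:
        raise ValueError("size must be non-negative at run time")
    out = []
    for i in range(size):
        src = start + i * stride
        # In-range reads copy the input; out-of-range reads use the fill.
        out.append(data[src] if 0 <= src < len(data) else fill)
    return out


values = [1.0, 2.0, 3.0]
print(slice_fill_1d(values, start=0, size=5))  # padded with the fill value
print(slice_fill_1d(values, start=1, size=2))  # plain in-range slice
```

Because the requested size is passed in rather than baked into the network as a constant, no -1 ever has to appear inside the shape values themselves.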
