You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
# isort: off
import torch
import tensorrt as trt
# isort: on
from polygraphy.backend.trt import (CreateConfig, EngineFromNetwork, Profile,
TrtRunner)
import tensorrt_llm
from tensorrt_llm import Tensor
class TestFunctional(unittest.TestCase):
    """Repro for ``tensorrt_llm.functional.expand`` on a dynamic-batch input.

    Both tests declare a float32 network input of shape ``(-1, 10)``, whose
    first axis is dynamic. Each expands it to a constant target shape and then
    queries axis 0 of the result with ``tensorrt_llm.functional.shape``.
    Per the accompanying issue, the ``[-1, 10]`` target fails inside that
    ``shape`` call, while the ``[1, 10]`` target does not.
    """

    def setUp(self):
        tensorrt_llm.logger.set_level('error')

    def _build_expand_network(self, expand_shape):
        """Build a network that expands a ``(-1, 10)`` input to ``expand_shape``.

        Args:
            expand_shape: Target shape. It is wrapped in
                ``np.array`` and passed as a constant tensor to
                ``tensorrt_llm.functional.expand``.

        Returns:
            A ``(builder, net, build_engine)`` tuple. ``build_engine`` is a
            polygraphy ``EngineFromNetwork`` loader that is returned without
            being invoked, so no engine is built here.
        """
        dtype = 'float32'
        input_shape = (-1, 10)

        builder = tensorrt_llm.Builder()
        net = builder.create_network()
        with tensorrt_llm.net_guard(net):
            # Named 'input' so the optimization profile below can refer to it.
            input_tensor = Tensor(name='input',
                                  shape=input_shape,
                                  dtype=tensorrt_llm.str_dtype_to_trt(dtype))
            shape = tensorrt_llm.functional.constant(np.array(expand_shape))
            output = tensorrt_llm.functional.expand(input_tensor, shape)
            # Querying axis 0 of the expanded tensor is the step that the
            # issue reports as failing when expand_shape contains -1.
            tensorrt_llm.functional.shape(output, 0)

        # The profile must name the actual network input ('input') and use
        # concrete min/opt/max shapes. Only axis 0 is dynamic; axis 1 is
        # fixed at 10.
        profiles = [Profile().add('input', (1, 10), (5, 10), (10, 10))]
        build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network),
                                         config=CreateConfig(profiles=profiles))
        return builder, net, build_engine

    def test_expand_dyn_input(self):
        # Static target shape: dynamic-batch input expanded to [1, 10].
        self._build_expand_network([1, 10])

    def test_expand_dyn_inout(self):
        # Target shape keeps the dynamic batch via -1 (the reported failure).
        self._build_expand_network([-1, 10])
Rename the above file to test_expand_dyn.py and run:
python3 -m unittest test_expand_dyn.py
Expected behavior
The test should succeed.
actual behavior
The test fails with the following:
[TensorRT-LLM] TensorRT-LLM version: 0.10.0
[07/03/2024-22:05:51] [TRT] [E] ITensor::getDimensions: Error Code 4: Internal Error (Tensor (Unnamed Layer* 1) [Slice]_output has axis 0 with inherently negative length. Proven upper bound is -1. Network must have an instance where axis has non-negative length.)
[07/03/2024-22:05:51] [TRT] [E] ITensor::getDimensions: Error Code 4: Internal Error (Output shape can not be computed for node SHAPE_4.)
E.
======================================================================
ERROR: test_expand_dyn_inout (test_expand_dyn.TestFunctional)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/home/TensorRT-LLM/tests/functional/test_expand_dyn.py", line 72, in test_expand_dyn_inout
output_shape = tensorrt_llm.functional.shape(output, 0)
File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/functional.py", line 1779, in shape
return gather(res, dim=0, indices=dim).view([])
File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/functional.py", line 1831, in gather
assert input.rank() == indices.rank()
File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/functional.py", line 476, in rank
return len(self.trt_tensor.shape)
ValueError: __len__() should return >= 0
----------------------------------------------------------------------
Ran 2 tests in 1.425s
FAILED (errors=1)
additional notes
When compiling a llama-like model using trtllm-build, tensors with a batch size dimension have that dimension set to -1. However, operators like expand seem unable to handle an input or output tensor with a dimension of size -1. The same applies to slice, which expand appears to use internally, judging by the [Slice] layer in the error log. As a result, subsequent operations on the output fail to build.
The text was updated successfully, but these errors were encountered:
When batch_size is used as an input, I have to obtain it with batch_size = grouped_features["batch_size"].item(), and the following line then runs into another bug:
queries = query.unsqueeze(1).expand(batch_size, max_seq_length, -1)
System Info
AWS p3.8xlarge (https://aws.amazon.com/ec2/instance-types/p3/)
Who can help?
No response
Information
Tasks
examples
folder (such as GLUE/SQuAD, ...)
Reproduction
test_expand_dyn.py.txt
Rename the above file to test_expand_dyn.py and run:
python3 -m unittest test_expand_dyn.py
Expected behavior
The test should succeed.
actual behavior
The test fails with the following:
additional notes
When compiling a llama-like model using `trtllm-build`, tensors with a batch size dimension have that dimension set to -1. However, operators like `expand` (and, by extension, `slice`) seem unable to handle an input or output tensor with a dimension of size -1, so subsequent operations on the output fail to build.
The text was updated successfully, but these errors were encountered: