
Performance Issue with inflight_batcher_llm Model in v0.13.0 #622

Open

junstar92 opened this issue Oct 18, 2024 · 6 comments

junstar92 commented Oct 18, 2024

Description

I'm reporting a significant performance degradation when using the ensemble model in inflight_batcher_llm provided with TensorRT-LLM Backend v0.13.0. In benchmarks run in my environment, performance dropped substantially compared to earlier versions.

After investigating, I found that the root cause of the issue lies within the postprocessing model in the ensemble, which is implemented in Python. Specifically, there is a performance bottleneck when calculating the tokenizer's vocabulary size on every execute() call.

if tokens[i] < len(self.tokenizer.vocab):

Before v0.13.0, this model used self.tokenizer.vocab_size to access the vocabulary size, which caused no performance issues. Starting from v0.13.0, however, this was changed to len(self.tokenizer.vocab). In the execute() method, len(self.tokenizer.vocab) is called repeatedly, and in my environment each call takes approximately 40 ms: accessing tokenizer.vocab returns a dictionary, and len() is then computed on that dictionary.

This issue becomes even more critical in streaming mode, where output tokens are generated one by one. In streaming scenarios, len(self.tokenizer.vocab) is called each time a token is generated, and the overhead accumulates over the entire generation process. As more tokens are generated, these repeated calls significantly slow down the model's performance.
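For a rough sense of scale: at about 40 ms per call, as measured below, streaming 1,000 output tokens spends roughly 1,000 × 40 ms ≈ 40 s in this vocabulary lookup alone.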

Here is a simple Python script to measure the execution time of len(self.tokenizer.vocab):

from transformers import AutoTokenizer
import time

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
start = time.time()
vocab_size = len(tokenizer.vocab)
end = time.time()
print(f'elapsed time: {(end - start) * 1000:.5f} ms')
# print "elapsed time: 40.49420 ms" in my case

start = time.time()
vocab = tokenizer.vocab
end = time.time()
print(f'elapsed time: {(end - start) * 1000:.5f} ms / id: {id(vocab)}')
# print "elapsed time: 37.93049 ms / id: 140201663341888"
start = time.time()
vocab = tokenizer.vocab
end = time.time()
print(f'elapsed time: {(end - start) * 1000:.5f} ms / id: {id(vocab)}')
# print "elapsed time: 40.11583 ms / id: 140201663341824"

start = time.time()
vocab_size = len(vocab)
end = time.time()
print(f'elapsed time: {(end - start) * 1000:.5f} ms')
# print "elapsed time: 0.00095 ms"

Every call to tokenizer.vocab creates a new dictionary object, which causes noticeable performance overhead.
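
For context, here is a simplified sketch of where that cost comes from. This paraphrases the behavior of Hugging Face fast tokenizers and is not the actual library source: vocab is a property that delegates to get_vocab(), which materializes a fresh Python dict from the underlying Rust tokenizer on every access.

# Simplified illustration of why accessing tokenizer.vocab is expensive.
# This mirrors (not copies) the behavior of transformers' fast tokenizers.
class FastTokenizerSketch:
    def __init__(self, rust_tokenizer):
        # rust_tokenizer is a tokenizers.Tokenizer instance
        self._tokenizer = rust_tokenizer

    def get_vocab(self):
        # Builds a brand-new {token: id} dict on every call.
        return self._tokenizer.get_vocab(with_added_tokens=True)

    @property
    def vocab(self):
        # tokenizer.vocab is just a property around get_vocab(), so every
        # access pays the full dict-construction cost again.
        return self.get_vocab()

This also explains the different id() values printed above: each access returns a different dictionary object.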

Proposed Solution:

This issue can be addressed simply by pre-computing total_vocab_size during the initialize call and reusing the pre-computed value in subsequent execute calls. After making this modification, my benchmark results returned to normal performance levels.

class TritonPythonModel:
    ...
    def initialize(self, args):
        ...
        self.total_vocab_size = len(self.tokenizer.vocab)

    ...
    def _postprocessing(self, tokens_batch, sequence_lengths):
        ...
                    if tokens[i] < self.total_vocab_size:
                        ...
        ...
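
For completeness, here is a minimal standalone sketch of the same caching pattern outside of Triton. The class and method names are illustrative, and MODEL_PATH is a placeholder for the tokenizer path, as in the timing script above.

from transformers import AutoTokenizer

MODEL_PATH = "path/to/tokenizer"  # placeholder, same as in the timing script

class Postprocessor:
    def __init__(self, model_path):
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        # Pay the ~40 ms dictionary build exactly once, at startup.
        self.total_vocab_size = len(self.tokenizer.vocab)

    def filter_tokens(self, tokens):
        # Hot path: a cached integer comparison per token, with no
        # len(self.tokenizer.vocab) call per request.
        return [t for t in tokens if t < self.total_vocab_size]

post = Postprocessor(MODEL_PATH)
print(post.filter_tokens([1, 5, 10**9]))  # out-of-vocabulary ids are dropped

The same pattern maps directly onto initialize() and _postprocessing() in the postprocessing model's model.py.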

Who can help?

@kaiyux @byshiue @schetlur-nv

Alireza3242 mentioned this issue Nov 9, 2024
@Columpio

Looks like it is fixed by a new commit in main

@frosk1

frosk1 commented Dec 18, 2024

Could this be the reason I am seeing very slow performance when serving a TensorRT-LLM model with this backend through the ensemble model, compared to benchmarks of a similar model deployed with NIM?

@frosk1

frosk1 commented Dec 18, 2024

I am using the 24.11-trtllm-python-py3 image

@schetlur-nv
Collaborator

@frosk1 can you share some more details of what you are running, like your build and run commands? There could be multiple factors at play here.
This bug should be fixed in the container you are pointing to.

@frosk1

frosk1 commented Dec 22, 2024

I will add a new issue for it

@frosk1

frosk1 commented Dec 22, 2024

here:
#667
