Performance Issue with inflight_batcher_llm Model in v0.13.0 #622
Comments
Looks like it is fixed by a new commit in main.

Could this be the reason why I am seeing very slow performance serving a tensorrt-llm model with the backend using the ensemble model, compared to benchmarks for a similar model deployed with NIM? I am using the 24.11-trtllm-python-py3 image.

@frosk1 can you share some more details of what you are running - like your build and run commands? There could be multiple factors at play here.

I will add a new issue for it, here:

Description
I'm reporting a significant performance degradation when using the `ensemble` model in the `inflight_batcher_llm` provided in TensorRT-LLM Backend v0.13.0. During benchmarks in my environment, I observed that the performance dropped substantially.

After investigating, I found that the root cause of the issue lies within the postprocessing model in the `ensemble`, which is implemented in Python. Specifically, there is a performance bottleneck when calculating the tokenizer's vocabulary size on every `execute()` call:

`tensorrtllm_backend/all_models/inflight_batcher_llm/postprocessing/1/model.py`, line 243 at commit `f80395e`

Before v0.13.0, this model used `self.tokenizer.vocab_size` to access the vocabulary size, which had no performance issues. However, starting from v0.13.0, this was changed to `len(self.tokenizer.vocab)`. In the `execute()` method, `len(self.tokenizer.vocab)` is called repeatedly, and this operation has been measured to take approximately 40 ms per execution in my environment: `tokenizer.vocab` returns a dictionary, and the code then calculates the length of that dictionary.

This issue becomes even more critical in streaming mode, where output tokens are generated one by one. In streaming scenarios, `len(self.tokenizer.vocab)` is called each time a token is generated, and the resulting overhead accumulates over the entire generation process. As more tokens are generated, the repeated invocation significantly slows down the model's performance; at roughly 40 ms per call, for example, generating 500 tokens in streaming mode would add on the order of 20 seconds of postprocessing overhead.

Here is simple Python code to measure the execution time of `len(self.tokenizer.vocab)`:
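
A minimal sketch of such a measurement, assuming a Hugging Face `transformers` tokenizer (the `gpt2` checkpoint and the iteration count are placeholders, not from the original report):

```python
import time

from transformers import AutoTokenizer

# Any Hugging Face tokenizer works here; "gpt2" is only an example.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

n = 100

# v0.13.0 behaviour: tokenizer.vocab rebuilds a full dict on every access.
start = time.perf_counter()
for _ in range(n):
    _ = len(tokenizer.vocab)
print(f"len(tokenizer.vocab): {(time.perf_counter() - start) / n * 1000:.3f} ms per call")

# Pre-v0.13.0 behaviour: vocab_size is a cheap property lookup.
start = time.perf_counter()
for _ in range(n):
    _ = tokenizer.vocab_size
print(f"tokenizer.vocab_size: {(time.perf_counter() - start) / n * 1000:.3f} ms per call")
```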

Every call to `tokenizer.vocab` creates a new dictionary object, which causes noticeable performance overhead.

Proposed Solution:
This issue can be addressed simply by pre-computing the `total_vocab_size` during the `initialize` call and reusing this pre-computed value in subsequent `execute` calls. After making this modification, the benchmark results returned to normal performance levels.
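
A minimal sketch of that change, assuming only the general shape of the postprocessing model; the class, the attribute name, and the way the cached value is consumed below are illustrative rather than the exact patch:

```python
from transformers import AutoTokenizer


class Postprocessor:
    """Illustrative stand-in for the postprocessing TritonPythonModel."""

    def initialize(self, tokenizer_dir: str) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
        # Compute the vocabulary size once: tokenizer.vocab rebuilds a dict on
        # every access, so caching the length removes the per-request cost.
        self.total_vocab_size = len(self.tokenizer.vocab)

    def execute(self, token_ids: list[int]) -> str:
        # Reuse the cached value instead of calling len(self.tokenizer.vocab)
        # per request. (How the value is used here is purely illustrative.)
        in_range = [t for t in token_ids if t < self.total_vocab_size]
        return self.tokenizer.decode(in_range)


# Usage sketch ("gpt2" is only an example checkpoint):
# post = Postprocessor()
# post.initialize("gpt2")
# print(post.execute([15496, 995]))
```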

Who can help?
@kaiyux @byshiue @schetlur-nv