[None][chore] Mass integration of release/1.0 - 2nd #7171
Changes from all commits
ee20f84
b7a7977
76736a1
eea0ebd
b2c953f
ccd36f4
60a944c
92e209c
8928405
771786f
0a4f757
f519b8c
be7c94f
e223cdb
34feef8
7ecbcc2
89ddff3
0b9c2ca
a0edae4
c959a07
ebe78d8
3b8c574
f4378c2
f64603e
ac36633
caa1897
7e98138
788adf2
7c1529b
e6b473a
c5fc171
3e29505
7040a83
@@ -1,10 +1,5 @@
 # This file defines code ownership rules for the repository.
 
-# The following rule should only be uncommented on release branches (e.g., release/0.19).
-# The rule below requires that any PR to release/**/* branches must be approved by at least one member
-# of the NVIDIA/trt-llm-release-branch-approval team, regardless of who else approves the PR.
-# Without approval from a member of this team, PRs cannot be merged to release branches.
-# * @NVIDIA/trt-llm-release-branch-approval
 
 ## TensorRT-LLM Infra
 ### CI
@@ -160,3 +155,9 @@ docs/source/performance/perf-benchmarking.md @NVIDIA/trtllm-bench-reviewers
 # from a member of this team, PRs affecting public APIs cannot be merged to main or release branches.
 /tests/unittest/api_stability/ @NVIDIA/trt-llm-noncommitted-api-review-committee
 /tests/unittest/api_stability/references_committed/ @NVIDIA/trt-llm-committed-api-review-committee
+
+# The following rule should only be uncommented on release branches (e.g., release/0.19).
+# The rule below requires that any PR to release/**/* branches must be approved by at least one member
+# of the NVIDIA/trt-llm-release-branch-approval team, regardless of who else approves the PR.
+# Without approval from a member of this team, PRs cannot be merged to release branches.
+# * @NVIDIA/trt-llm-release-branch-approval
Comment on lines +159 to +163

🛠️ Refactor suggestion

Clarify the comment: enabling the catch-all on release branches will override ALL specific owners (last-match-wins). If you need AND semantics (a release approver in addition to path/API owners), CODEOWNERS can't express it with a trailing catch-all.

-# The following rule should only be uncommented on release branches (e.g., release/0.19).
-# The rule below requires that any PR to release/**/* branches must be approved by at least one member
-# of the NVIDIA/trt-llm-release-branch-approval team, regardless of who else approves the PR.
-# Without approval from a member of this team, PRs cannot be merged to release branches.
+# Release-branch only — DO NOT uncomment on main.
+# Important: CODEOWNERS is last-match-wins. Uncommenting the catch-all below on a release branch
+# will override all specific owners above (including API review committees). If you need AND semantics
+# (release approver in addition to per-path owners), maintain a separate CODEOWNERS on release/**
+# that appends @NVIDIA/trt-llm-release-branch-approval to each path, or enforce via Rulesets “Required reviewers”.
+# Example for release/* (keep commented here):
 # * @NVIDIA/trt-llm-release-branch-approval
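If a release branch ever needs the per-path variant described in the comment above, a small throwaway script could generate it from the main CODEOWNERS. The sketch below is illustrative only and not part of this PR; the input/output file names and the idea of appending the release team to every rule are assumptions taken from the review comment.

```python
# Hypothetical helper (not part of this PR): derive a release-branch CODEOWNERS by
# appending the release-approval team to every ownership rule, as the review comment
# above suggests. Input/output paths are assumptions for illustration only.
from pathlib import Path

RELEASE_TEAM = "@NVIDIA/trt-llm-release-branch-approval"


def append_release_team(codeowners_text: str) -> str:
    out = []
    for line in codeowners_text.splitlines():
        stripped = line.strip()
        # Leave comments and blank lines untouched; extend every real rule.
        if stripped and not stripped.startswith("#") and RELEASE_TEAM not in line:
            out.append(f"{line} {RELEASE_TEAM}")
        else:
            out.append(line)
    return "\n".join(out) + "\n"


if __name__ == "__main__":
    Path("CODEOWNERS.release").write_text(
        append_release_team(Path("CODEOWNERS").read_text()))
```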
@@ -392,8 +392,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
         std::vector<int64_t> output_shape = {num_rows, unpadded_hidden_size_val};
         auto output = torch::empty(output_shape, input.options().dtype(mOutputDtype));
 
-        WorkspaceInfo workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
-            static_cast<int>(experts_per_token), base_activation_type, parallelism_config, min_latency_mode);
+        WorkspaceInfo const& workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
+            static_cast<int>(experts_per_token), base_activation_type, parallelism_config, min_latency_mode, stream);
 
         auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
         kernels::MoeMinLatencyParams min_latency_params{};
@@ -553,8 +553,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
         min_latency_params.experts_to_token_score = static_cast<float*>(experts_to_token_score.data_ptr());
         min_latency_params.active_expert_global_ids = static_cast<int*>(active_expert_global_ids.data_ptr());
 
-        WorkspaceInfo workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
-            static_cast<int>(experts_per_token), base_activation_type, parallelism_config, min_latency_mode);
+        WorkspaceInfo const& workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
+            static_cast<int>(experts_per_token), base_activation_type, parallelism_config, min_latency_mode, stream);
 
         auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
 
@@ -709,6 +709,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
     // e.g. 16 nvfp4 elements are packed into a single int64 element
     int64_t mInnerDimMultiplier;
     char* mProfileWorkspace = nullptr;
+    WorkspaceInfo workspace_info;
 
     bool mUseDeepSeekFP8BlockScaling = false;
     bool mUseW4GroupScaling = false;
@@ -757,9 +758,9 @@ class FusedMoeRunner : public torch::CustomClassHolder
         mKernelRunner->setTactic(best_gemm1_profile, best_gemm2_profile);
     }
 
-    WorkspaceInfo getWorkspaceInfo(int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
+    WorkspaceInfo const& getWorkspaceInfo(int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
         int num_experts, int experts_per_token, ActivationType activation_type,
-        kernels::MOEParallelismConfig const& parallelismConfig, bool min_latency_mode)
+        kernels::MOEParallelismConfig const& parallelismConfig, bool min_latency_mode, cudaStream_t stream)
     {
         size_t moe_workspace_size = mKernelRunner->getWorkspaceSize(num_rows, hidden_size, inter_size, num_experts,
             experts_per_token, activation_type, parallelismConfig, /* use_lora */ false, mUseDeepSeekFP8BlockScaling,
@@ -768,15 +769,29 @@ class FusedMoeRunner : public torch::CustomClassHolder
 
         std::vector<size_t> workspaces{moe_workspace_size, src_to_dest_map_size};
 
-        size_t total_workspace_size = common::calculateTotalWorkspaceSize(workspaces.data(), workspaces.size());
+        int64_t const total_workspace_size = common::calculateTotalWorkspaceSize(workspaces.data(), workspaces.size());
 
-        WorkspaceInfo info{};
-        info.workspace = torch::empty({static_cast<long>(total_workspace_size)},
-            torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
-        info.src_to_dest_map
-            = common::nextWorkspacePtr(static_cast<int8_t*>(info.workspace.data_ptr()), moe_workspace_size);
+        bool is_capturing = tensorrt_llm::common::isCapturing(stream);
+        // Always allocate workspace when capturing cuda graph to avoid illegal memory access during replay
+        if (is_capturing || workspace_info.workspace.numel() < total_workspace_size)
+        {
+            if (is_capturing)
+            {
+                TLLM_LOG_DEBUG(
+                    "Allocating MoE workspace with %ld bytes size during cuda graph capture", total_workspace_size);
+            }
+            else
+            {
+                TLLM_LOG_DEBUG("MoE workspace size is not enough, increase the size from %ld bytes to %ld bytes",
+                    workspace_info.workspace.numel(), total_workspace_size);
+            }
+            workspace_info.workspace = torch::empty({static_cast<long>(total_workspace_size)},
+                torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
+        }
Comment on lines +772 to +790

🛠️ Refactor suggestion

Use size_t for byte sizes and correct printf-specifiers; avoid narrowing to long.

Apply this diff:

-        int64_t const total_workspace_size = common::calculateTotalWorkspaceSize(workspaces.data(), workspaces.size());
+        size_t const total_workspace_size = common::calculateTotalWorkspaceSize(workspaces.data(), workspaces.size());
@@
-                TLLM_LOG_DEBUG(
-                    "Allocating MoE workspace with %ld bytes size during cuda graph capture", total_workspace_size);
+                TLLM_LOG_DEBUG(
+                    "Allocating MoE workspace with %zu bytes size during cuda graph capture", total_workspace_size);
@@
-                TLLM_LOG_DEBUG("MoE workspace size is not enough, increase the size from %ld bytes to %ld bytes",
-                    workspace_info.workspace.numel(), total_workspace_size);
+                TLLM_LOG_DEBUG("MoE workspace size is not enough, increase the size from %lld bytes to %zu bytes",
+                    static_cast<long long>(mWorkspaceInfo.workspace.numel()), total_workspace_size);
@@
-            workspace_info.workspace = torch::empty({static_cast<long>(total_workspace_size)},
+            mWorkspaceInfo.workspace = torch::empty({static_cast<int64_t>(total_workspace_size)},
                 torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
+
+        workspace_info.src_to_dest_map
+            = common::nextWorkspacePtr(static_cast<int8_t*>(workspace_info.workspace.data_ptr()), moe_workspace_size);
 
-        return info;
+        return workspace_info;
     }
 
     kernels::QuantParams getQuantParams(int64_t const num_experts_on_rank, int64_t const hidden_size,
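For readers more familiar with the Python side, here is a minimal PyTorch sketch of the caching policy the C++ change above implements: keep one grow-only workspace tensor, but always allocate a fresh one while a CUDA graph is being captured so that replay does not hit a stale pointer. The class and method names are illustrative, not from the PR, and the sketch assumes a CUDA device is available.

```python
import torch


class MoeWorkspaceCache:
    """Illustrative grow-only workspace cache mirroring the C++ logic above (not the actual implementation)."""

    def __init__(self) -> None:
        self._workspace = torch.empty(0, dtype=torch.int8, device="cuda")

    def get(self, required_bytes: int) -> torch.Tensor:
        capturing = torch.cuda.is_current_stream_capturing()
        # Reallocate when capturing a CUDA graph (to avoid illegal memory access on
        # replay) or when the cached buffer is too small; otherwise reuse it as-is.
        if capturing or self._workspace.numel() < required_bytes:
            self._workspace = torch.empty(required_bytes, dtype=torch.int8, device="cuda")
        return self._workspace
```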
@@ -174,7 +174,8 @@ FROM wheel AS tritonbuild
 WORKDIR /src/tensorrt_llm
 RUN pip install /src/tensorrt_llm/build/tensorrt_llm*.whl
 COPY ./triton_backend/ ./triton_backend/
-RUN bash ./triton_backend/inflight_batcher_llm/scripts/build.sh
+ARG TRITON_BASE_TAG
+RUN bash ./triton_backend/inflight_batcher_llm/scripts/build.sh -s "r${TRITON_BASE_TAG%-py3}"
Comment on lines +177 to 179

🛠️ Refactor suggestion

Pass TRITON_SHORT_TAG from TRITON_BASE_TAG: good; broaden suffix stripping. The current pattern strips only "-py3". If we ever move to tags like "-py3.11", the suffix won't be removed. Use "-py3*".

Apply this diff:

-RUN bash ./triton_backend/inflight_batcher_llm/scripts/build.sh -s "r${TRITON_BASE_TAG%-py3}"
+RUN bash ./triton_backend/inflight_batcher_llm/scripts/build.sh -s "r${TRITON_BASE_TAG%-py3*}"
 
 FROM release AS tritonrelease
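As a side note on the shell parameter expansion the review discusses, the intent is to turn a Triton base tag such as "25.05-py3" into the short form "r25.05". A small Python illustration of the difference between the two patterns follows; the tag values are made up for the example and are not taken from the PR.

```python
def short_tag(base_tag: str, broadened: bool = True) -> str:
    """Mimic "r${TRITON_BASE_TAG%-py3}" (broadened=False) vs "r${TRITON_BASE_TAG%-py3*}" (broadened=True)."""
    marker = "-py3"
    if broadened:
        # "%-py3*" drops the shortest suffix starting at "-py3", whatever follows it.
        idx = base_tag.rfind(marker)
        stripped = base_tag[:idx] if idx != -1 else base_tag
    else:
        # "%-py3" only drops a literal "-py3" suffix.
        stripped = base_tag[:-len(marker)] if base_tag.endswith(marker) else base_tag
    return "r" + stripped


print(short_tag("25.05-py3", broadened=False))     # r25.05
print(short_tag("25.05-py3.11", broadened=False))  # r25.05-py3.11 (suffix not stripped)
print(short_tag("25.05-py3.11", broadened=True))   # r25.05
```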
@@ -0,0 +1,9 @@
+# LLM API with TensorRT Engine
+A simple inference example with TinyLlama using the LLM API:
+
+```{literalinclude} ../../examples/llm-api/_tensorrt_engine/quickstart_example.py
+:language: python
+:linenos:
+```
+
Comment on lines +4 to +7

💡 Verification agent

🧩 Analysis chain

Fix the literalinclude relative path (currently off by one level). The code block includes the script via "../../examples/…", but this document lives under docs/source/legacy. To reach the repo-level examples directory, you need to go up three levels (not two). The link below (line 9) already uses "../../../…", which further indicates the mismatch.

Apply this diff:

-```{literalinclude} ../../examples/llm-api/_tensorrt_engine/quickstart_example.py
+```{literalinclude} ../../../examples/llm-api/_tensorrt_engine/quickstart_example.py
 :language: python
 :linenos:

Length of output: 11225

Let's bypass the PATH issue and invoke Sphinx as a module:

🏁 Script executed:

#!/bin/bash
# Strict docs build to fail on warnings (invoke via module to avoid PATH problems)
python -m sphinx -nW docs/source docs/_build/html

Length of output: 2359

🏁 Script executed:

#!/bin/bash
# Verify existence of example file via both relative paths from the docs file
BASE="docs/source/legacy"
INCORRECT="../../examples/llm-api/_tensorrt_engine/quickstart_example.py"
CORRECT="../../../examples/llm-api/_tensorrt_engine/quickstart_example.py"
echo "Checking incorrect path (../../examples/...):"
if [[ -f "${BASE}/${INCORRECT}" ]]; then
  echo "FOUND: ${BASE}/${INCORRECT}"
else
  echo "MISSING: ${BASE}/${INCORRECT}"
fi
echo "Checking corrected path (../../../examples/...):"
if [[ -f "${BASE}/${CORRECT}" ]]; then
  echo "FOUND: ${BASE}/${CORRECT}"
else
  echo "MISSING: ${BASE}/${CORRECT}"
fi

Length of output: 949

Fix the literalinclude path in docs/source/legacy/tensorrt_quickstart.md. The current directive only goes up two levels; it needs three to reach the repo-level examples directory.

Apply this patch:

--- docs/source/legacy/tensorrt_quickstart.md
@@ lines 4-7
-```{literalinclude} ../../examples/llm-api/_tensorrt_engine/quickstart_example.py
+```{literalinclude} ../../../examples/llm-api/_tensorrt_engine/quickstart_example.py
 :language: python
 :linenos:
+For more advanced usage including distributed inference, multimodal, and speculative decoding, please refer to this [README](../../../examples/llm-api/README.md).
💡 Verification agent

🧩 Analysis chain

Avoid linking to files outside the docs tree (may break on the published site). The relative link to ../../../examples/... likely won't exist in the built site output. Prefer an absolute GitHub URL, or move/duplicate the README content into the docs tree and link via :doc: or toctree.

To validate locally (fail on warnings and catch include/link issues):

🏁 Script executed:

#!/bin/bash
set -euo pipefail
python -m pip install -r docs/requirements.txt
# Build HTML with nitpicky warnings as errors
python -m sphinx -nW docs/source docs/_build/html
# Spot-check the target file exists relative to this doc (sanity)
BASE="docs/source/legacy"
TARGET="../../../examples/llm-api/README.md"
if [[ -f "${BASE}/${TARGET}" ]]; then
  echo "Found file at ${BASE}/${TARGET} (but it still won't be copied into the site)."
else
  echo "File not found at ${BASE}/${TARGET} — link will be broken in source as well."
fi

Length of output: 13460

🏁 Script executed:

#!/bin/bash
# Verify existence of the linked README relative to the docs tree
if [[ -f docs/source/legacy/../../../examples/llm-api/README.md ]]; then
  echo "Found file at docs/source/legacy/../../../examples/llm-api/README.md"
else
  echo "File not found at docs/source/legacy/../../../examples/llm-api/README.md"
fi

Length of output: 212

Ensure documentation links remain within the built site. Linking directly to ../../../examples/llm-api/README.md will be broken on the published site, since the examples directory is not copied into the build output. To resolve, either link to the README with an absolute GitHub URL, or move/duplicate its content into the docs tree and reference it via :doc: or a toctree entry. Please update the link accordingly to prevent broken references in the published site.
@@ -0,0 +1,33 @@
+from tensorrt_llm import LLM, SamplingParams
+
🛠️ Refactor suggestion

Add the NVIDIA SPDX header per repo guidelines. Examples are source files and should carry the NVIDIA SPDX header.

Apply this diff to prepend the header:

+# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+
 from tensorrt_llm import LLM, SamplingParams
+
+def main():
+
+    # Model could accept HF model name, a path to local HF model,
+    # or TensorRT Model Optimizer's quantized checkpoints like nvidia/Llama-3.1-8B-Instruct-FP8 on HF.
+    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+
Comment on lines +1 to +9

🛠️ Refactor suggestion

Import the TRT engine LLM and accept --engine_dir.

-from tensorrt_llm import LLM, SamplingParams
+import argparse
+from tensorrt_llm._tensorrt_engine import LLM, SamplingParams
@@
 def main():
-
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--engine_dir", type=str, default=None)
+    args = parser.parse_args()
@@
-    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", engine_dir=args.engine_dir)
+    # Sample prompts.
+    prompts = [
+        "Hello, my name is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+    # Create a sampling params.
+    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+
+    for output in llm.generate(prompts, sampling_params):
+        print(
+            f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}"
+        )
+
+    # Got output like
+    # Prompt: 'Hello, my name is', Generated text: '\n\nJane Smith. I am a student pursuing my degree in Computer Science at [university]. I enjoy learning new things, especially technology and programming'
+    # Prompt: 'The president of the United States is', Generated text: 'likely to nominate a new Supreme Court justice to fill the seat vacated by the death of Antonin Scalia. The Senate should vote to confirm the'
+    # Prompt: 'The capital of France is', Generated text: 'Paris.'
+    # Prompt: 'The future of AI is', Generated text: 'an exciting time for us. We are constantly researching, developing, and improving our platform to create the most advanced and efficient model available. We are'
+
+
+if __name__ == '__main__':
+    main()
@@ -1,11 +1,17 @@
-from tensorrt_llm import LLM, SamplingParams
+from tensorrt_llm import BuildConfig, SamplingParams
🛠️ Refactor suggestion

Add the NVIDIA SPDX header per repo guidelines.

+# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+
 from tensorrt_llm import BuildConfig, SamplingParams
+from tensorrt_llm._tensorrt_engine import LLM  # NOTE the change
Comment on lines +1 to +2

🛠️ Refactor suggestion

Use the public API import for LLM; avoid the private module path. End-user examples should import LLM from the public package namespace rather than from the private tensorrt_llm._tensorrt_engine module.

Apply this diff:

-from tensorrt_llm import BuildConfig, SamplingParams
-from tensorrt_llm._tensorrt_engine import LLM  # NOTE the change
+from tensorrt_llm import BuildConfig, SamplingParams, LLM
 
 
 def main():
 
+    build_config = BuildConfig()
+    build_config.max_batch_size = 256
+    build_config.max_num_tokens = 1024
+
     # Model could accept HF model name, a path to local HF model,
     # or TensorRT Model Optimizer's quantized checkpoints like nvidia/Llama-3.1-8B-Instruct-FP8 on HF.
-    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+              build_config=build_config)
 
     # Sample prompts.
     prompts = [
@@ -122,6 +122,15 @@ def add_multimodal_args(parser):
                             " ├── __init__.py"
                             " ├── <model_name>.py"
                             " └── <sub_dirs>"))
+    # Add multiturn conversation related parameters
+    parser.add_argument("--multiturn",
+                        action="store_true",
+                        help="Enable multi-turn conversation mode.")
+    parser.add_argument(
+        "--conversation_turns",
+        type=int,
+        default=2,
+        help="Number of conversation turns for automated testing.")
     return parser
@@ -188,6 +197,80 @@ def main():
         f"Unsupported model_type: {model_type} found!\n" \
         f"Supported types: {MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_model_types()}"
 
+    # If multiturn mode is enabled
+    if args.multiturn:
+        # Run predefined multiturn conversation examples
+        assert args.prompt is not None, "Please provide a prompt for multiturn conversation."
+        assert args.media is not None, "Please provide media for multiturn conversation."
+        # Determine how many turns to run
+        max_turns = min(args.conversation_turns, len(args.prompt))
+        generated_outputs = []  # Store generated outputs for return
+
+        # Initialize conversation history with the first prompt
+        conversation_history = args.prompt[0] if args.prompt else ""
+
+        for i in range(max_turns):
+            print(f"\n--- Turn {i+1} ---")
+
+            try:
+                # Use multimodal input loader to process input with conversation context
+                # Use accumulated conversation history instead of just the current prompt
+                cur_prompt = conversation_history
+                inputs = default_multimodal_input_loader(
+                    tokenizer=llm.tokenizer,
+                    model_dir=llm._hf_model_dir,
+                    model_type=model_type,
+                    modality=args.modality,
+                    prompts=[cur_prompt],
+                    media=args.media,
+                    image_data_format="pt",
+                    num_frames=8,
+                    device="cpu")
+
Comment on lines +219 to +229

🛠️ Refactor suggestion

Fix: nested-media modalities will assert in the loader; also honor the CLI image_format/num_frames/device values. default_multimodal_input_loader asserts when len(prompts) == 1 and media is a list-of-lists (e.g., modality="image_audio"). Your current call passes prompts=[cur_prompt] and media=args.media unchanged, which will trip the assert for nested media. Additionally, the code ignores user CLI values and hardcodes "pt"/8/"cpu".

Apply this refactor to (1) select a single sample for nested-media modalities, (2) pass user-specified format/frames/device, and (3) keep the model_dir type consistent:

-                inputs = default_multimodal_input_loader(
-                    tokenizer=llm.tokenizer,
-                    model_dir=llm._hf_model_dir,
-                    model_type=model_type,
-                    modality=args.modality,
-                    prompts=[cur_prompt],
-                    media=args.media,
-                    image_data_format="pt",
-                    num_frames=8,
-                    device="cpu")
+                # For nested-media (e.g., image_audio = [ [img,aud], [img,aud], ... ]),
+                # pick one sample to pair with a single-turn prompt. For flat media
+                # (image/video/audio), 1 prompt + N media is supported by the loader.
+                media_for_turn = args.media
+                if isinstance(media_for_turn, list) and media_for_turn and isinstance(media_for_turn[0], list):
+                    media_for_turn = [media_for_turn[0]]
+
+                inputs = default_multimodal_input_loader(
+                    tokenizer=llm.tokenizer,
+                    model_dir=str(llm._hf_model_dir),
+                    model_type=model_type,
+                    modality=args.modality,
+                    prompts=[cur_prompt],
+                    media=media_for_turn,
+                    image_data_format=image_format,
+                    num_frames=args.num_frames,
+                    device=args.device)

Follow-up: If you want to reuse the same nested media across turns, consider extracting the first sample once outside the loop and reusing it to avoid repeated conditionals.
+                lora_request = None
+                if args.load_lora:
+                    if model_class is None:
+                        raise ValueError(
+                            "model_class must be provided when load_lora is True"
+                        )
+                    lora_request = model_class.lora_request(
+                        len(inputs), args.modality, llm._hf_model_dir)
+
+                # Generate response
+                outputs = llm.generate(inputs,
+                                       sampling_params,
+                                       lora_request=lora_request)
+                assert outputs and len(
+                    outputs) > 0 and outputs[0].outputs and len(
+                        outputs[0].outputs) > 0
+                response = outputs[0].outputs[0].text.strip()
+
+                # Store generated output
+                generated_outputs.append({
+                    "turn": i + 1,
+                    "user_input": cur_prompt,
+                    "assistant_response": response,
+                    "media": args.media
+                })
+
+                conversation_history = conversation_history + "\n" + response
+                if i + 1 < len(args.prompt):
+                    conversation_history = conversation_history + "\n" + args.prompt[
+                        i + 1]
+
+            except Exception as e:
+                print(f"Error in turn {i+1}: {e}")
+                import traceback
+                traceback.print_exc()
+                continue
+
+        for i, output in enumerate(generated_outputs):
+            print(
+                f"[{i}] Prompt: {output['user_input']!r}, Generated text: {output['assistant_response']!r}"
+            )
+        return
+
+    # Original single-turn processing logic
     # set prompts and media to example prompts and images if they are not provided
     if args.prompt is None:
         args.prompt = example_medias_and_prompts[args.modality]["prompt"]
Reviewer: This should not be done.

Reviewer: Yes, we need to find a way of preventing this from being merged back into main.

Reviewer: Curious how this has been handled all this while in previous MIs 🤔