feat: Support weight streaming #3111
@@ -0,0 +1,174 @@
""" | ||
.. _weight_streaming_example: | ||
|
||
Weight Streaming | ||
======================= | ||
|
||
Weight streaming in TensorRT is a powerful feature designed to overcome GPU memory limitations | ||
when working with large models. It enables running models larger than available GPU memory | ||
by streaming weight data from host (CPU) memory to GPU memory during inference. | ||
|
||
Streaming larger amounts of memory will likely result in lower performance. But if | ||
streaming weights allows the user to run larger batch sizes and it can lead to higher throughput. | ||
This increased throughput can sometimes outweigh the slowdown caused by streaming weights. | ||
The optimal amount of memory to stream varies depending on the specific model and hardware. | ||
Experimenting with different memory limits can help find the best balance between streaming | ||
overhead and batch size benefits. | ||
|
||
This example uses a pre-trained Llama-2 model and show how to use weight streaming feature with | ||
Torch-TensorRT. | ||
1. compile option - build trt engine with weight streaming feature | ||
2. runtime api - weight streaming budget control by context manager | ||
""" | ||
# %%
# Imports and Model Definition
# ----------------------------------

import copy
import timeit

import numpy as np
import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM
from utils import export_llm


def time_generate(model, inputs, output_seq_length, iterations=10):
    """
    Measure the time for generating a sentence over a certain number of iterations
    """
    # We only support a single input (B x seq_len) for LLMs now
    input_seq = inputs[0]
    with torch.no_grad():
        timings = []
        for _ in range(iterations):
            start_time = timeit.default_timer()
            inputs_copy = copy.copy(input_seq)
            # Greedy decoding of the model. This generates up to max_tokens.
            while inputs_copy.shape[1] <= output_seq_length:
                outputs = model(inputs_copy)
                logits = outputs.logits
                next_token_logits = logits[:, -1, :]
                next_tokens = torch.argmax(next_token_logits, dim=-1)
                inputs_copy = torch.cat([inputs_copy, next_tokens[:, None]], dim=-1)
            torch.cuda.synchronize()
            end_time = timeit.default_timer()
            timings.append(end_time - start_time)

    times = np.array(timings)
    time_mean_ms = np.mean(times) * 1000

    return time_mean_ms
# Load the LLaMA-2 model
DEVICE = torch.device("cuda:0")
llama_path = "meta-llama/Llama-2-7b-chat-hf"
with torch.no_grad():
    model = AutoModelForCausalLM.from_pretrained(
        llama_path, use_cache=False, attn_implementation="eager"
    ).eval()

# Set input and output sequence lengths
isl = 128
osl = 256

# Create random input tensors
input_tensors = [torch.randint(0, 5, (1, isl), dtype=torch.int64).cuda()]
# Convert the model to half precision (FP16)
model = model.half()
# Export the LLM model into an ExportedProgram with dynamic shapes.
llama2_ep = export_llm(model, input_tensors[0], max_seq_len=osl)
# %%
# Compiler option
# ----------------------------------
#
# The enable_weight_streaming=True and use_explicit_typing=True options are required to build
# the engine with the weight streaming feature. The use_explicit_typing=True option creates a
# `strongly typed network <https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#strongly-typed-networks>`_, and only float32 precision is allowed in the enabled_precisions option.
#

# Create a TensorRT-compiled model
trt_model = torch_tensorrt.dynamo.compile(
    llama2_ep,
    inputs=input_tensors,
    enabled_precisions={torch.float32},
    truncate_double=True,
    device=DEVICE,
    use_explicit_typing=True,
    enable_weight_streaming=True,
)
# Warm up for 3 iterations
_ = time_generate(trt_model, input_tensors, osl, 3)

# %%
# Running with automatic budget size
# ----------------------------------
#
# Once the enable_weight_streaming compile option is specified, an automatic budget size is configured.
# This automatic size may not always provide the optimal solution, because the automatically determined
# budget lacks insight into the user's specific memory constraints and usage patterns.

# Weight streaming context to get current weight budget information
weight_streaming_ctx = torch_tensorrt.runtime.weight_streaming(trt_model)
# Measure the mean latency of the model with weight streaming
mean_latency = time_generate(trt_model, input_tensors, osl, 1)
# Calculate the percentage of the current weight budget used
weight_budget_pct = (
    weight_streaming_ctx.device_budget / weight_streaming_ctx.total_device_budget * 100
)
print(
    f"Set weight streaming budget as {weight_budget_pct}%. {weight_streaming_ctx.device_budget} bytes out of {weight_streaming_ctx.total_device_budget}. mean latency = {mean_latency} ms"
)
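The automatic budget appears to be derived from the GPU memory that is free at the time it is queried (see the review discussion below), so it can help to inspect free device memory yourself before choosing a manual budget. A minimal sketch using torch.cuda.mem_get_info; this snippet is illustrative and not part of the original example:

# Sketch: inspect free/total GPU memory to guide a manual budget choice.
free_bytes, total_bytes = torch.cuda.mem_get_info(DEVICE)
print(
    f"Free GPU memory: {free_bytes / 2**30:.2f} GiB out of {total_bytes / 2**30:.2f} GiB"
)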
# %%
# Running with weight streaming context manager
# ----------------------------------
#
# The weight streaming budget can be limited by using the weight streaming context manager.
# The permissible range for the budget size is from 0 to ctx.total_device_budget.
# 0 means maximum memory savings occur by using the minimum amount of memory. A value
# equal to ctx.total_device_budget will disable weight streaming.
# If multiple TRT engines are created, budgets are distributed proportionally.

# Use a context manager for weight streaming
with torch_tensorrt.runtime.weight_streaming(trt_model) as weight_streaming_ctx:
    # Get the total size of streamable weights in the engine
    streamable_budget = weight_streaming_ctx.total_device_budget

    # Scenario 1: Automatic weight streaming budget
    # Get the automatically determined weight streaming budget
    requested_budget = weight_streaming_ctx.get_automatic_weight_streaming_budget()
    # Set the device budget to the automatically determined value
    weight_streaming_ctx.device_budget = requested_budget
    # Measure the mean latency with the automatic budget
    mean_latency = time_generate(trt_model, input_tensors, osl, 1)
    # Calculate the percentage of the weight budget used
    weight_budget_pct = (
        weight_streaming_ctx.device_budget
        / weight_streaming_ctx.total_device_budget
        * 100
    )
    print(
        f"Set auto weight streaming budget as {weight_budget_pct}%. {weight_streaming_ctx.device_budget} bytes out of {weight_streaming_ctx.total_device_budget}. mean latency = {mean_latency} ms"
    )

Review comment: Would the values here be the same as line 117? If we set enable_weight_streaming=True but don't set the automatic budget explicitly?

Reply: In my test, it was different. It seems the automatic budget is calculated from the current free memory size, so it can differ depending on the calling point.
    # Scenario 2: Manual 10% weight streaming budget
    # Set the budget to 10% of the total streamable weights
    requested_budget = int(streamable_budget * 0.1)
    weight_streaming_ctx.device_budget = requested_budget
    # Measure the mean latency with 10% budget
    mean_latency = time_generate(trt_model, input_tensors, osl, 1)
    # Calculate the percentage of the weight budget used
    weight_budget_pct = (
        weight_streaming_ctx.device_budget
        / weight_streaming_ctx.total_device_budget
        * 100
    )
    print(
        f"Set weight streaming budget as {weight_budget_pct}%. {weight_streaming_ctx.device_budget} bytes out of {weight_streaming_ctx.total_device_budget}. mean latency = {mean_latency} ms"
    )
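As noted above, setting the budget equal to ctx.total_device_budget disables weight streaming. A third scenario that uses this to obtain a no-streaming baseline inside the same context manager might look like the sketch below; the print format here is illustrative and not part of the original example:

    # Scenario 3 (sketch): disable weight streaming by granting the full budget,
    # so all weights stay resident in GPU memory.
    weight_streaming_ctx.device_budget = weight_streaming_ctx.total_device_budget
    # Measure the mean latency without streaming as a baseline
    mean_latency = time_generate(trt_model, input_tensors, osl, 1)
    print(
        f"Disabled weight streaming (100% budget). mean latency = {mean_latency} ms"
    )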
Review comment: How fast did it run for you? It was slow for me in the past, and hence I used only 1 warm-up iteration.

Reply: It took 5 seconds in my test with an RTX 4080.