feat: Implement FP32 accumulation for matmul #3110

Merged: 81 commits, Oct 11, 2024
Changes from 65 commits
2ea181a
chore: add gpt2 example
peri044 Jun 13, 2024
37b65a5
chore: add llama2 example
peri044 Jun 13, 2024
bd12b12
Merge branch 'main' into llm_examples_main
peri044 Jun 13, 2024
4a9f73e
Merge branch 'main' into llm_examples_main
peri044 Jun 14, 2024
0387d0b
Merge branch 'main' into llm_examples_main
peri044 Jun 14, 2024
6193939
chore: updates
peri044 Jun 14, 2024
9d3296e
Merge branch 'main' into llm_examples_main
peri044 Jun 14, 2024
84fc49c
Merge branch 'main' into llm_examples_main
peri044 Jun 18, 2024
ff17d91
chore: rebase
peri044 Jun 18, 2024
8e6ba26
Merge branch 'llm_examples_main' of github.com:pytorch/TensorRT into …
peri044 Jun 24, 2024
67ec408
Merge branch 'main' into llm_examples_main
peri044 Jun 25, 2024
9af8e39
chore: remove aten.full decomposition
peri044 Jun 25, 2024
50d4096
chore: fix expand DS support
peri044 Jun 25, 2024
59febf5
chore: minor fix
peri044 Jun 26, 2024
c3e4382
chore: updates
peri044 Jun 26, 2024
0673db4
chore: add testcase
peri044 Jun 26, 2024
0b62f8f
Merge branch 'main' into full
peri044 Jun 26, 2024
54f6410
Merge branch 'full' into fix_expand_ds
peri044 Jun 26, 2024
ae3d6b2
Merge branch 'fix_expand_ds' into llm_examples_main
peri044 Jun 26, 2024
4464fd5
chore: updates
peri044 Jun 26, 2024
63b13cf
chore: updates
peri044 Jun 28, 2024
3d10b92
Merge branch 'main' into llm_examples_main
peri044 Jun 28, 2024
e97a94f
chore: updates
peri044 Jul 10, 2024
4f503a8
chore: updates
peri044 Jul 11, 2024
5ecf63e
chore: rebase
peri044 Jul 11, 2024
0d00d8c
chore: updates
peri044 Jul 11, 2024
8099003
chore: updates
peri044 Jul 11, 2024
457f706
chore: updates
peri044 Jul 11, 2024
ce3b2f8
chore: updates
peri044 Jul 11, 2024
d8acadc
chore: updates
peri044 Jul 12, 2024
262c87d
chore: updates
peri044 Jul 12, 2024
bb94dfd
chore: rebase
peri044 Jul 17, 2024
736b839
chore: updates
peri044 Jul 17, 2024
313380e
chore: bug fixes
peri044 Jul 18, 2024
1057d83
chore: updates
peri044 Jul 19, 2024
bfd0cf2
chore: fixes
peri044 Jul 20, 2024
17ddb31
chore: updates
peri044 Jul 20, 2024
88be4fa
chore: add torch compile gpt2 example
peri044 Jul 22, 2024
df825ab
chore: updates
peri044 Jul 22, 2024
ff07295
chore: add timing calculation
peri044 Jul 24, 2024
857b0aa
Merge branch 'main' into llm_examples_main
peri044 Jul 24, 2024
8fae56b
Merge branch 'main' into llm_examples_main
peri044 Jul 29, 2024
d483718
chore: rebase
peri044 Jul 31, 2024
397e4bc
Merge branch 'main' into llm_examples_main
peri044 Aug 5, 2024
6c9b9fe
chore: updates
peri044 Aug 5, 2024
6313b1c
chore: updates
peri044 Aug 9, 2024
d608cc5
chore: rebase
peri044 Aug 9, 2024
1327782
chore: rebase fixes
peri044 Aug 9, 2024
0980778
chore: updates
peri044 Aug 9, 2024
94b2ba1
chore: updates
peri044 Aug 9, 2024
2b1db29
chore: updates
peri044 Aug 9, 2024
9f606fc
chore: updates
peri044 Aug 9, 2024
0cf23be
Merge branch 'main' into llm_examples_main
peri044 Aug 14, 2024
3228c57
chore: Update perf tooling with support for HF models (#3034)
peri044 Aug 15, 2024
6786f0e
chore: updates
Aug 15, 2024
e4873d0
chore: updates
peri044 Aug 19, 2024
a725ce0
Merge branch 'main' into llm_examples_main
peri044 Aug 19, 2024
bb10de4
feat: lowering replace aten.full_like with aten.full
chohk88 Aug 12, 2024
1527aa0
chore: minor linting
chohk88 Aug 12, 2024
67e33c3
chore: updates
peri044 Aug 19, 2024
5627c1a
Merge branch 'llm_examples_main' of github.com:pytorch/TensorRT into …
peri044 Aug 19, 2024
7be8604
chore: updates
peri044 Aug 21, 2024
4d75a2e
Merge branch 'main' into llm_examples_main
peri044 Aug 21, 2024
0ab0dbf
feat: add fp32 accumulation option for matmul layer
peri044 Aug 21, 2024
3c815f8
chore: updates
Aug 28, 2024
5617c0a
chore: Bump TRT version to 10.3.0.26 (#3071)
zewenli98 Aug 24, 2024
213526e
chore: updates
peri044 Aug 30, 2024
c193593
chore : updates
peri044 Aug 30, 2024
0de0b16
chore: updates
peri044 Sep 24, 2024
a90191d
chore: rebase with main
peri044 Sep 24, 2024
71e33cb
chore: updates
peri044 Sep 26, 2024
4257b1e
chore: updates
peri044 Sep 30, 2024
619a39a
chore: updates
peri044 Oct 1, 2024
8c0b9c6
chore: trunc_fiv fix
peri044 Oct 7, 2024
b6261f9
chore: update result
peri044 Oct 7, 2024
ebdfe8f
fix: add model.half() for llama2
peri044 Oct 7, 2024
61ec948
chore: address review comments
peri044 Oct 8, 2024
dd27a54
chore: address review comments
peri044 Oct 8, 2024
b2e5244
chore: add docs
peri044 Oct 8, 2024
7ddd637
chore: updates
peri044 Oct 8, 2024
4529717
chore: sign bug fix
peri044 Oct 10, 2024
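
Summary of the change, for orientation before the diffs: this PR adds a `use_fp32_acc` option to `torch_tensorrt.dynamo.compile` so that matmul layers accumulate their partial products in FP32 while the rest of the network runs in FP16. A minimal sketch of the intended usage follows; `exported_program` and `example_input` are hypothetical placeholder names, and the full GPT2/Llama2 walkthroughs below show the flag in context.

import torch
import torch_tensorrt

# Sketch only: `exported_program` and `example_input` are hypothetical
# placeholders standing in for a real torch.export output and its input.
trt_model = torch_tensorrt.dynamo.compile(
    exported_program,
    inputs=[example_input],
    enabled_precisions={torch.float16},  # run the network in FP16
    use_fp32_acc=True,  # accumulate matmul partial products in FP32
)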
93 changes: 93 additions & 0 deletions examples/dynamo/torch_export_gpt2.py
@@ -0,0 +1,93 @@
"""
.. _torch_export_gpt2:

Compiling GPT2 using Torch-TensorRT with the dynamo backend
============================================================

This interactive script is intended as a sample of the Torch-TensorRT workflow with the dynamo backend on a GPT2 model."""

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import export_llm, generate

# %%

# Define the parameters and initialize the model
MAX_TOKENS = 32
DEVICE = torch.device("cuda:0")

# Define the GPT2 model from hugging face
# kv_cache is not supported in Torch-TRT currently.
# CPU is used here so that GPU memory is reserved for TRT compilation.
with torch.no_grad():
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = (
        AutoModelForCausalLM.from_pretrained(
            "gpt2",
            pad_token_id=tokenizer.eos_token_id,
            use_cache=False,
            attn_implementation="eager",
        )
        .eval()
        .half()
    )

# %%
# Tokenize a sample input prompt and get pytorch model outputs
prompt = "I enjoy walking with my cute dog"
model_inputs = tokenizer(prompt, return_tensors="pt")
input_ids = model_inputs["input_ids"]

# Auto-regressive generation loop for greedy decoding using PyTorch model
# We use a custom generate function which is very similar to the huggingface one.
pyt_gen_tokens = generate(model, input_ids, MAX_TOKENS, tokenizer.eos_token_id)


# %%
# Compilation with `Torch-TensorRT` using the dynamo backend and generation of TensorRT outputs
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Export the GPT2 model into an ExportedProgram which is input of TRT compilation
gpt2_ep = export_llm(model, input_ids, max_seq_len=1024)
with torch_tensorrt.logging.debug():
    trt_model = torch_tensorrt.dynamo.compile(
        gpt2_ep,
        inputs=[input_ids],
        enabled_precisions={torch.float16},
        truncate_double=True,
        device=DEVICE,
        disable_tf32=True,
        use_strong_types=False,
        use_fp32_acc=True,
    )

# Auto-regressive generation loop for greedy decoding using TensorRT model
# We use a custom generate function which is very similar to the huggingface one.
# Move inputs to GPU
input_ids = input_ids.to(DEVICE)
trt_gen_tokens = generate(trt_model, input_ids, MAX_TOKENS, tokenizer.eos_token_id)

# %%
# Decode the output sentences of PyTorch and TensorRT
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
print("=============================")
print(
    "Pytorch model generated text: ",
    tokenizer.decode(pyt_gen_tokens[0], skip_special_tokens=True),
)
print("=============================")
print(
    "TensorRT model generated text: ",
    tokenizer.decode(trt_gen_tokens[0], skip_special_tokens=True),
)

# %%
# The output sentences should look like
# =============================
# Pytorch model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my
# =============================
# TensorRT model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my
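
Why FP32 accumulation matters here: a standalone sketch (not part of this PR) showing how a long FP16 reduction loses mass once the running sum's spacing exceeds the addend; this is exactly the error mode that `use_fp32_acc=True` avoids in long matmul reductions.

import torch

# Standalone illustration: add 4096 copies of 0.01 in FP16 vs. FP32.
# Once the FP16 running sum reaches 32, its spacing (ULP) is 0.03125,
# so adding 0.01 rounds back to the same value and the sum stalls.
addend = torch.tensor(0.01, dtype=torch.float16)
acc16 = torch.zeros((), dtype=torch.float16)
acc32 = torch.zeros((), dtype=torch.float32)
for _ in range(4096):
    acc16 = acc16 + addend
    acc32 = acc32 + addend.float()
print(f"fp16 accumulator: {acc16.item():.2f}")  # stalls around 32
print(f"fp32 accumulator: {acc32.item():.2f}")  # ~40.97, near the true sum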
90 changes: 90 additions & 0 deletions examples/dynamo/torch_export_llama2.py
@@ -0,0 +1,90 @@
"""
.. _torch_export_llama2:

Compiling Llama2 using Torch-TensorRT with the dynamo backend
==============================================================

This interactive script is intended as a sample of the Torch-TensorRT workflow with the dynamo backend on a Llama2 model."""

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import export_llm, generate

# %%
# Define the parameters and initialize the model
MAX_TOKENS = 32
DEVICE = torch.device("cuda:0")

# Define the Llama2 model from hugging face
# kv_cache is not supported in Torch-TRT currently.
# CPU is used here so that GPU memory is reserved for TRT compilation.
llama_path = "meta-llama/Llama-2-7b-chat-hf"
with torch.no_grad():
    model = AutoModelForCausalLM.from_pretrained(
        llama_path, use_cache=False, attn_implementation="eager"
    ).eval()

tokenizer = AutoTokenizer.from_pretrained(llama_path)

# %%
# Tokenize a sample input prompt and get pytorch model outputs
prompt = "What is dynamic programming?"
model_inputs = tokenizer(prompt, return_tensors="pt")
input_ids = model_inputs.input_ids

# Auto-regressive generation loop for greedy decoding using PyTorch model
# We use a custom generate function which is very similar to the huggingface one.
pyt_gen_tokens = generate(model, input_ids, MAX_TOKENS, tokenizer.eos_token_id)

# %%
# Compilation with `Torch-TensorRT` using the dynamo backend and generation of TensorRT outputs
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Export the llama2 model into an ExportedProgram which is input of TRT compilation
llama2_ep = export_llm(model, input_ids, max_seq_len=64)
trt_model = torch_tensorrt.dynamo.compile(
    llama2_ep,
    inputs=[input_ids],
    enabled_precisions={torch.float32},
    min_block_size=1,
    truncate_double=True,
    device=DEVICE,
    disable_tf32=True,
)

# Auto-regressive generation loop for greedy decoding using TensorRT model
# We use a custom generate function which is very similar to the huggingface one.
# Move inputs to GPU
input_ids = input_ids.to(DEVICE)
trt_gen_tokens = generate(trt_model, input_ids, MAX_TOKENS, tokenizer.eos_token_id)

# %%
# Decode the output sentences of PyTorch and TensorRT
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
print("=============================")
print(
"Pytorch model generated text: ",
tokenizer.batch_decode(
pyt_gen_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0],
)
print("=============================")
print(
"TensorRT model generated text: ",
tokenizer.batch_decode(
trt_gen_tokens,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)[0],
)

# %%
# The output sentences should look like
# =============================
# Pytorch model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my
# =============================
# TensorRT model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my
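
The Llama2 example above compiles in FP32. A hedged variation, not the file as merged, that mirrors the GPT2 settings: halve the model first, as the `model.half()` commit in this PR suggests, and enable the new flag.

# A sketch under stated assumptions, not the merged example file:
# compile Llama2 in FP16 with FP32 matmul accumulation.
model = model.half()
llama2_ep = export_llm(model, input_ids, max_seq_len=64)
trt_model = torch_tensorrt.dynamo.compile(
    llama2_ep,
    inputs=[input_ids],
    enabled_precisions={torch.float16},
    min_block_size=1,
    truncate_double=True,
    device=DEVICE,
    disable_tf32=True,
    use_fp32_acc=True,
)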
63 changes: 63 additions & 0 deletions examples/dynamo/utils.py
@@ -0,0 +1,63 @@
import torch
from transformers import StoppingCriteriaList
from transformers.generation.stopping_criteria import (
EosTokenCriteria,
MaxLengthCriteria,
)


def export_llm(model, inputs, min_seq_len=1, max_seq_len=16):
    """
    Exports the LLM model into an ExportedProgram with dynamic shapes.
    In the case of guard failures due to some PyTorch kernel implementations,
    we also try to re-export the graph by expressing them as runtime assert nodes.
    """
    with torch.no_grad():
        # max=1024 has constraint violation error. https://github.com/pytorch/pytorch/issues/125604
        seq_len = torch.export.Dim("seq_len", min=min_seq_len, max=max_seq_len)
        try:
            print("Trying to export the model using torch.export.export()..")
            # strict=False only enables aotautograd tracing and excludes dynamo.
            ep = torch.export.export(
                model, (inputs,), dynamic_shapes=({1: seq_len},), strict=False
            )
        except:
            print(
                "Trying torch.export._trace._export to trace the graph since torch.export.export() failed"
            )
            # This API is used to express the constraint violation guards as asserts in the graph.
            ep = torch.export._trace._export(
                model,
                (inputs,),
                dynamic_shapes=({1: seq_len},),
                strict=False,
                allow_complex_guards_as_runtime_asserts=True,
            )

    return ep


def generate(model, input_seq, max_tokens, eos_token_id):
    """
    Greedy decoding of the model. This generates up to max_tokens.
    """
    # Max length of output seq = current input_seq length + max_tokens allowed to generate
    max_output_seq_length = input_seq.shape[1] + max_tokens
    stopping_criteria = StoppingCriteriaList(
        [
            MaxLengthCriteria(max_length=max_output_seq_length),
            EosTokenCriteria(eos_token_id=eos_token_id),
        ]
    )

    while True:
        outputs = model(input_seq)
        logits = outputs.logits
        next_token_logits = logits[:, -1, :]
        next_tokens = torch.argmax(next_token_logits, dim=-1)
        input_seq = torch.cat([input_seq, next_tokens[:, None]], dim=-1)
        # TODO: Handle batch in this check
        if stopping_criteria(input_seq, logits).item():
            break

    return input_seq
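
A usage sketch for these helpers, mirroring the example scripts above (the `model`, `input_ids`, and `tokenizer` names come from those scripts):

# Export once with a dynamic sequence dimension; the ExportedProgram is
# what torch_tensorrt.dynamo.compile consumes in the scripts above.
ep = export_llm(model, input_ids, max_seq_len=1024)

# Greedy-decode up to 32 new tokens with the (PyTorch or TRT) model.
tokens = generate(model, input_ids, 32, tokenizer.eos_token_id)
print(tokenizer.decode(tokens[0], skip_special_tokens=True))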
57 changes: 32 additions & 25 deletions py/torch_tensorrt/dynamo/_DryRunTracker.py
@@ -20,18 +20,18 @@ class PerSubgraphData:
     Args:
         subgraph_name (str): Name of the subgraph in the GraphModule
         subgraph_op_count (int): Number of operations in the subgraph
-        subgraph_input_shapes (Any): Shapes of input Tensors of the subgraph
-        subgraph_input_dtypes (Any): Input data types of the subgraph
-        subgraph_output_shapes (Any): Shapes of output Tensors of the subgraph
-        subgraph_output_dtypes (Any): Output data types of the subgraph
+        input_shapes (Any): Shapes of input Tensors of the subgraph
+        input_dtypes (Any): Input data types of the subgraph
+        output_shapes (Any): Shapes of output Tensors of the subgraph
+        output_dtypes (Any): Output data types of the subgraph
     """

     subgraph_name: str = ""
     subgraph_op_count: int = 0
-    subgraph_input_shapes: Any = field(default_factory=list)
-    subgraph_input_dtypes: Any = field(default_factory=list)
-    subgraph_output_shapes: Any = field(default_factory=list)
-    subgraph_output_dtypes: Any = field(default_factory=list)
+    input_shapes: Any = field(default_factory=list)
+    input_dtypes: Any = field(default_factory=list)
+    output_shapes: Any = field(default_factory=list)
+    output_dtypes: Any = field(default_factory=list)


 @dataclass
@@ -41,10 +41,10 @@ class DryRunTracker:
     Args:
         total_ops_in_graph (int): Total number of operators in graph
         supported_ops_in_graph (int): Number of supported operators in graph
-        graph_input_shapes (Any): Shapes of input Tensors of the graph
-        graph_input_dtypes (Any): Input data types of the graph
-        graph_output_shapes (Any): Shapes of output Tensors of the graph
-        graph_output_dtypes (Any): Output data types of the graph
+        input_shapes (Any): Shapes of input Tensors of the graph
+        input_dtypes (Any): Input data types of the graph
+        output_shapes (Any): Shapes of output Tensors of the graph
+        output_dtypes (Any): Output data types of the graph
         per_subgraph_data (List[PerSubgraphData]): Per-subgraph data, see above class
         tensorrt_graph_count (int): Number of TensorRT engines to be generated
         compilation_settings (CompilationSettings): User Compilation Settings
@@ -54,10 +54,10 @@

     total_ops_in_graph: int = 0
     supported_ops_in_graph: int = 0
-    graph_input_shapes: Any = field(default_factory=list)
-    graph_input_dtypes: Any = field(default_factory=list)
-    graph_output_shapes: Any = field(default_factory=list)
-    graph_output_dtypes: Any = field(default_factory=list)
+    input_shapes: Any = field(default_factory=list)
+    input_dtypes: Any = field(default_factory=list)
+    output_shapes: Any = field(default_factory=list)
+    output_dtypes: Any = field(default_factory=list)
     per_subgraph_data: List[PerSubgraphData] = field(default_factory=list)
     tensorrt_graph_count: int = 0
     compilation_settings: CompilationSettings = field(
@@ -111,7 +111,7 @@ def dryrun_stats_display(
     formatted_stats += " " * 2 + "Graph Structure:\n\n"
     formatted_stats += (
         " " * 3
-        + f"Inputs: {input_formatter(dryrun_tracker.graph_input_shapes, dryrun_tracker.graph_input_dtypes)}\n"
+        + f"Inputs: {input_formatter(dryrun_tracker.input_shapes, dryrun_tracker.input_dtypes)}\n"
     )

     for i, trt_subgraph_data in enumerate(dryrun_tracker.per_subgraph_data):
@@ -122,21 +122,21 @@
         )
         formatted_stats += (
             " " * 5
-            + f"Engine Inputs: {input_formatter(trt_subgraph_data.subgraph_input_shapes, trt_subgraph_data.subgraph_input_dtypes)}\n"
+            + f"Engine Inputs: {input_formatter(trt_subgraph_data.input_shapes, trt_subgraph_data.input_dtypes)}\n"
         )
         formatted_stats += (
             " " * 5
             + f"Number of Operators in Engine: {trt_subgraph_data.subgraph_op_count}\n"
         )
         formatted_stats += (
             " " * 5
-            + f"Engine Outputs: {input_formatter(trt_subgraph_data.subgraph_output_shapes, trt_subgraph_data.subgraph_output_dtypes)}\n"
+            + f"Engine Outputs: {input_formatter(trt_subgraph_data.output_shapes, trt_subgraph_data.output_dtypes)}\n"
         )

     formatted_stats += " " * 4 + "...\n"
     formatted_stats += (
         " " * 3
-        + f"Outputs: {input_formatter(dryrun_tracker.graph_output_shapes, dryrun_tracker.graph_output_dtypes)}\n"
+        + f"Outputs: {input_formatter(dryrun_tracker.output_shapes, dryrun_tracker.output_dtypes)}\n"
     )

# Print aggregate statistics about the graph structure, including recommended "min_block_size" options
@@ -225,11 +225,18 @@ def input_formatter(shapes: Any, dtypes: Any) -> str:

     def input_formatter_helper(shapes: Any, dtypes: Any) -> str:
         """Helper for input formatter"""
-        # Base case - single shape, single dtype
-        if isinstance(shapes, tuple) and all(isinstance(elt, int) for elt in shapes):
-            return f"Tensor: {shapes}@{str(dtypes)[6:]}, "
+        # Base case 1 - single static/dynamic shape, single dtype
+        if isinstance(shapes, tuple) and all(isinstance(elt, (int, tuple)) for elt in shapes):
+            input_shape_string = "Tensor: ("
+            for elt in shapes:
+                if isinstance(elt, tuple):
+                    input_shape_string += f"(min={elt[0]}, max={elt[1]}), "
+                else:
+                    input_shape_string += f"{elt}, "
+            input_shape_string = input_shape_string[:-2] + ")" + f"@{str(dtypes)[6:]}, "
+            return input_shape_string

-        # Base case - dynamic shape, single dtype
+        # Base case 2 - dynamic shape, single dtype
         elif (
             isinstance(shapes, dict)
             and len(shapes) == 3
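
To see what the new "Base case 1" branch produces, here is a standalone re-implementation for illustration; the function name and the sample shape are made up, but the logic mirrors the added lines above (ints are static dims, (min, max) tuples are dynamic dims).

def format_shape(shapes: tuple, dtype_name: str) -> str:
    # Mirrors the new "Base case 1" branch of input_formatter_helper.
    out = "Tensor: ("
    for elt in shapes:
        if isinstance(elt, tuple):
            out += f"(min={elt[0]}, max={elt[1]}), "
        else:
            out += f"{elt}, "
    return out[:-2] + ")" + f"@{dtype_name}, "

print(format_shape((1, (1, 64), 768), "float16"))
# Tensor: (1, (min=1, max=64), 768)@float16,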