
[FEAT] Performance Profiler #495

Closed · wants to merge 33 commits

Conversation

@jeromeku (Collaborator) commented Jul 10, 2024

@msaroufim

Add Performance Profiler

Initial implementation of a performance profiler per #426.

Overview

The primary contribution is a TransformerPerformanceCounter, which records data movement and FLOPs across multiple contexts.

Combined with a DeviceSpec, it accumulates theoretical peak performance and utilization stats for each individual context and in aggregate, which can be used for Speed of Light / roofline analysis and other downstream performance profiling.

The motivation is to provide a lightweight way of collecting useful performance stats using (mostly) torch-native features, as a complement to torch.profiler and as a first step before diving into tools such as Nsight Compute.

Details

Below is an example demonstrating the basic API.

from torchao.profiler import CUDADeviceSpec, TransformerPerformanceCounter

# Device spec is a dataclass that contains device info such as name, peak bandwidth, peak FLOPs
# If these fields are not manually specified, they will be automatically populated using
# CUDA runtime functions exposed by `torch.cuda` and `triton.runtime.driver`  
device_spec = CUDADeviceSpec()

# The manager object tracks latency, data movement (bytes), and FLOPs across multiple contexts and
# maintains performance metrics for each context and in aggregate.
manager = TransformerPerformanceCounter(device_spec=device_spec)

# Prefill
with manager.count(label="prefill", num_tokens=x.numel()):
    out = model(encoded_prompt)
# Print recorded stats for "prefill" context
manager.print_summary(labels=["prefill"], show=True) 

# Decode
with manager.count(label="decode", num_tokens=1):
    out = model(out[-1])
# Print recorded stats for "decode" context
manager.print_summary(labels=["decode"], show=True) 

# Print accumulated stats across all contexts
manager.print_summary(show=True)
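
For reference, model, encoded_prompt, and x above are placeholders for the user's own generation setup. Below is a self-contained sketch of the prefill measurement, using a hypothetical toy module in place of a real transformer (assumes a CUDA device and the profiler API shown in this PR):

import torch
import torch.nn as nn

from torchao.profiler import CUDADeviceSpec, TransformerPerformanceCounter

device_spec = CUDADeviceSpec()
manager = TransformerPerformanceCounter(device_spec=device_spec)

# Hypothetical stand-in for a real transformer and tokenized prompt
model = nn.Sequential(nn.Embedding(32000, 1024), nn.Linear(1024, 32000)).to(
    device="cuda", dtype=torch.bfloat16
)
tokens = torch.randint(0, 32000, (1, 6), device="cuda")

with manager.count(label="prefill", num_tokens=tokens.numel()):
    logits = model(tokens)

manager.print_summary(labels=["prefill"], show=True)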
  • CUDADeviceSpec is a lightweight dataclass that models device info

    • auto-populated with device-specific information using torch.cuda and triton.runtime.driver CUDA APIs where possible
    • a hard-coded table is maintained in torchao.profiler.device_spec._AVAILABLE_GPU_SPECS to fill in peak FLOPs, though it should be possible to calculate this directly using the CUDA driver API.
  • TransformerPerformanceCounter uses PerformanceCounterMode under the hood to capture data movement and FLOPs

    • PerformanceCounterMode is an extended version of torch.utils.flop_counter.FlopCounterMode that counts data movement and FLOPs per aten operator and organizes them by torch.nn.Module via __torch_dispatch__ (a minimal sketch of this dispatch pattern follows the list below).
  • Metrics are encapsulated by PerformanceStats:

    @dataclass
    class PerformanceStats(DictMixin):
        """
        Data struct that stores performance statistics.
    
        Attrs:
            num_tokens (int): number of tokens processed
            latency (float): latency in seconds
            total_flops (int): total FLOPs
            total_io (int): total data movement in bytes
            flops_summary (Dict[str, int]): summary of FLOPs by module
            io_summary (Dict[str, int]): summary of data movement in bytes by module
            flop_counts (Dict[str, Dict[Any, int]]): FLOP counts by module and operation
            io_counts (Dict[str, Dict[Any, int]]): data movement by module and operation
            device_bandwidth (Optional[float]): device bandwidth in bytes per second
            device_flops_per_s (Optional[float]): device FLOPs per second
    
        Additionally, the following derived properties are available:
            token_throughput (float): number of tokens processed per second
            achieved_flops_per_s (float): achieved FLOPs per second
            achieved_bandwidth (float): achieved data movement in bytes per second
            theoretical_io_latency (Optional[float]): theoretical I/O latency in seconds, set to None if
            no device bandwidth is available.
            theoretical_compute_latency (Optional[float]): theoretical compute latency in seconds, set to None if
            no device FLOPs are available.
        """
  • In addition to the raw / derived metrics, a TransformerPerformanceCounter also has convenience methods for summarizing accumulated stats. From the above example,

    manager.print_summary(show=True)

    will print:

    Performance Summary:
        Latency = 1.42 s
        Tokens
            Total: 7 tokens
            Throughput: 5 tokens/s
        IO
            Total: 26.47 GB
            Throughput: 18.69 GB/s
            Theoretical Latency: 28.28 ms
        FLOPs 
            Total: 92.53 GFLOPs
            Throughput: 65.33 GFLOPs/s
            Theoretical Latency: 2.60 ms
        Utilization
            Bandwidth: 0.0200 %
            FLOPs: 0.0018 %
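
To make the PerformanceCounterMode bullet above concrete, here is a minimal sketch of the underlying dispatch pattern. This is not the PR's actual implementation, just the general idea: intercept every aten op via __torch_dispatch__ and tally a rough byte count per op (the class and field names below are hypothetical):

import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten


class SimpleIOCounter(TorchDispatchMode):
    """Tallies an approximate byte count of tensor inputs/outputs per aten op."""

    def __init__(self):
        super().__init__()
        self.io_by_op = {}

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        out = func(*args, **(kwargs or {}))
        # Rough upper bound: bytes of all tensor args plus tensor outputs
        flat, _ = tree_flatten((args, kwargs, out))
        nbytes = sum(
            t.numel() * t.element_size() for t in flat if isinstance(t, torch.Tensor)
        )
        self.io_by_op[str(func)] = self.io_by_op.get(str(func), 0) + nbytes
        return out


counter = SimpleIOCounter()
with counter:
    torch.mm(torch.randn(256, 256), torch.randn(256, 256))
print(counter.io_by_op)  # per-op byte counts, e.g. keyed by "aten.mm.default"

PerformanceCounterMode additionally attributes these counts to the enclosing torch.nn.Module, mirroring what FlopCounterMode does for FLOPs.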

Tests

See test/profiler/test_device_spec.py and test/profiler/test_performance_counter.py for unit tests for each of these components.

Usage

An end-to-end example is available in tutorials/profiler.

  • It is a minimal adaptation of the generate.py script from gpt-fast with prettier printing.
  • Note that I've stripped out features such as quantization, tensor parallelism, and speculative decoding as the purpose is to demonstrate usage of TransformerPerformanceCounter.

Running the example for llama2-7b (on an RTX 3090) prints the following, with the outputs from TransformerPerformanceCounter annotated as such and those from the original gpt-fast script prefixed with GPTFast:

GPTFast
Loading model ...
Time to load model: 20.14 seconds

==============================

TransformerPerfCounter
Using DeviceSpec(device_type=cuda, name=NVIDIA GeForce RTX 3090, dtype=torch.bfloat16, bandwidth=936.1GB/s, flops=35.6TFLOPs, vram=25.4GB)

GPTFast
Model Config: ModelArgs(block_size=2048, vocab_size=32000, n_layer=32, n_head=32, dim=4096, intermediate_size=11008, n_local_heads=32, head_dim=128, rope_base=10000, norm_eps=1e-05)
Active params, Total Params: 6607343616, 6738415616

==============================

TransformerPerfCounter Metrics
PREFILL_SEQLEN-6:
  Latency = 1.26 s
  Tokens
    Total: 6 tokens
    Throughput: 5 tokens/s
  IO
    Total: 13.25 GB
    Throughput: 10.54 GB/s
    Theoretical Latency: 14.15 ms
  FLOPs 
    Total: 79.31 GFLOPs
    Throughput: 63.06 GFLOPs/s
    Theoretical Latency: 2.23 ms
  Utilization
    Bandwidth: 0.0113 %
    FLOPs: 0.0018 %

==============================

TransformerPerfCounter Metrics
DECODE_CTX-6_NUM_TOKS-1:
  Latency = 0.16 s
  Tokens
    Total: 1 tokens
    Throughput: 6 tokens/s
  IO
    Total: 13.22 GB
    Throughput: 83.27 GB/s
    Theoretical Latency: 14.13 ms
  FLOPs 
    Total: 13.22 GFLOPs
    Throughput: 83.24 GFLOPs/s
    Theoretical Latency: 0.37 ms
  Utilization
    Bandwidth: 0.0890 %
    FLOPs: 0.0023 %

==============================

GPTFast
Generated text for sample 0: Hello, my name is [Name

GPTFast Sample Metrics
  Time for inference 1: 6 prompt tokens 2 tokens generated, 1.57 sec total, 1.28 tokens/sec
  Bandwidth achieved: 17.22 GB/s

==============================

GPTFast Aggregate Stats
  Average tokens/sec: 1.28
  Memory used: 13.51 GB

==============================

TransformerPerfCounter
Performance Summary:
  Latency = 1.42 s
  Tokens
    Total: 7 tokens
    Throughput: 5 tokens/s
  IO
    Total: 26.47 GB
    Throughput: 18.69 GB/s
    Theoretical Latency: 28.28 ms
  FLOPs 
    Total: 92.53 GFLOPs
    Throughput: 65.33 GFLOPs/s
    Theoretical Latency: 2.60 ms
  Utilization
    Bandwidth: 0.0200 %
    FLOPs: 0.0018 %

Saving performance results to performance_stats.json
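
As a sanity check, the derived quantities in the aggregate Performance Summary can be reproduced from the raw totals and the DeviceSpec peaks. A rough back-of-the-envelope sketch using the values printed above (small rounding differences are expected since the printed figures are themselves rounded):

latency_s = 1.42           # s
total_io_gb = 26.47        # GB
total_gflops = 92.53       # GFLOPs
peak_bw_gbs = 936.1        # GB/s, from the DeviceSpec
peak_gflops_s = 35.6e3     # GFLOPs/s (35.6 TFLOPs)

achieved_bw = total_io_gb / latency_s                        # ~18.6 GB/s
achieved_flops = total_gflops / latency_s                    # ~65.2 GFLOPs/s
theoretical_io_latency = total_io_gb / peak_bw_gbs           # ~0.0283 s (28.3 ms)
theoretical_compute_latency = total_gflops / peak_gflops_s   # ~0.0026 s (2.6 ms)
bandwidth_utilization = achieved_bw / peak_bw_gbs            # ~0.020
flops_utilization = achieved_flops / peak_gflops_s           # ~0.0018

The last two values line up with the Bandwidth and FLOPs utilization figures in the summary above.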

TODO

  • Add more usage examples
  • Automatically calculate peak FLOPs for NVIDIA GPUs
  • More detailed metrics?
    • CUPTI / ncu
    • Combine with torch.profiler


pytorch-bot bot commented Jul 10, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/495

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 9b37dd3 with merge base 12ac498:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Jul 10, 2024
@@ -0,0 +1,257 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Member

This seems like gpt-fast right? In which case we have a fork already in https://github.com/pytorch/ao/tree/main/torchao/_models/llama

Collaborator Author

It's a simplified version with lots of features stripped out and cleaner printing to demonstrate usage of the performance counter.

See the Usage section of the PR and the README for more details.

Member

I see! I will be out this week for a friend's wedding but @andrewor14 mind reviewing this PR?

Collaborator Author

@msaroufim @andrewor14

Let me know if this is along the lines of what you had in mind regarding #426.

The core abstractions DeviceSpec and TransformerPerformanceCounter take care of tracking the necessary measurements (achieved BW and FLOPs/s) and detecting device-specific BW and FLOPs/s for MBU and MFU, as seen in the example output above.

Happy to adapt the API however useful.

@andrewor14 (Contributor) left a comment

Hi @jeromeku, thanks for working on this. It looks great overall! I left a few comments, mostly about code reuse and location. Another thing I'm wondering is how we should make it easy for developers to profile their quantized models. Do you think we should just integrate it into generate.py? Maybe @jerryzh168 and @HDCharles should take a look too.

]


def get_all_base_classes(object):
Contributor
This looks like it's only used in 1 place. Maybe we should inline?

return [cls.__name__.lower() for cls in inspect.getmro(object.__class__)]


def total_model_params(
Contributor

This looks like a util method. Should we move it to a utils file, e.g. torchao/profiler/utils.py? Also, if these are not meant to be called by the user, I would rename this to _total_model_params to make it private.

6: "E",
7: "Z",
8: "Y",
}
Contributor
can be a simple array?
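
For illustration, the suggested simplification might look like the following (a hypothetical sketch: the visible entries suggest the dict maps power-of-1000 exponents to unit suffixes, so the lower entries and the helper name below are inferred, not taken from the PR):

# Hypothetical replacement for the exponent -> suffix dict
_UNIT_SUFFIXES = ["", "K", "M", "G", "T", "P", "E", "Z", "Y"]

def _format_with_suffix(value: float) -> str:
    exp = 0
    while abs(value) >= 1000 and exp < len(_UNIT_SUFFIXES) - 1:
        value /= 1000
        exp += 1
    return f"{value:.2f}{_UNIT_SUFFIXES[exp]}"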

@@ -0,0 +1,339 @@
import pytest
Contributor

This feels more like a script that users will call instead of a tutorial. I feel this should just live under torchao/profiler/generate.py?

Contributor

same for the README

@@ -0,0 +1,112 @@
import os
Contributor

I agree with @msaroufim we should not duplicate this code (and model.py). We already copied them from gpt-fast so we should only have 1 version in torchao. Can you merge these with the existing ones under torchao/_models?

@msaroufim (Member) Jul 19, 2024

@jeromeku the code duplication is the only controversial thing about this PR; if we could reduce it, merging this would be a no-brainer.

We're planning a release on Aug 8, do you think you'll have time to land your changes before then?

@@ -0,0 +1,339 @@
import pytest
Contributor

nit: probably better to rename the file itself to something that includes profiler in the name as well, otherwise it will be a bit confusing I think

@jerryzh168 (Contributor)

I think this looks great. One thing is that we may have to think about where to put the example code (which also requires the llama model definition). We have an existing place for the llama model, https://github.com/pytorch/ao/tree/main/torchao/_models/llama, which also contains the eval and benchmark code. Having everything under the _models folder may not make sense; I feel we could move _models out of torchao and rename it to something that better reflects what it contains (model + benchmark script + eval script), and include the profiler script there as well.

@jeromeku (Collaborator Author)

@jerryzh168 @msaroufim @andrewor14

Thanks for the feedback -- will make the changes. Been caught up with some other things but should have some free time in the coming days.

@jeromeku (Collaborator Author)

@andrewor14 @jerryzh168 @msaroufim

Made the following changes:

  • Per this comment, refactored the stray methods in torchao/profiler/__init__.py to torchao/profiler/utils.py
  • Regarding code duplication wrt gpt-fast:
    • Moved tutorials/profiler/generate.py -> torchao/_models/llama/perf_profile.py, which now uses the existing model.py and tokenizer.py
    • Included the README from tutorials/profiler as a module level comment to explain differences from generate.py and explain usage of the profiler.
    • Checked that the script still runs

@jeromeku (Collaborator Author)

The CI failures aren't related to this PR...

@msaroufim (Member) commented Jul 27, 2024

I've seen this kind of large-scale IMA error when either:

  1. some code breaks an assumption of test isolation, for example by setting a global flag, or
  2. a segfault or IMA in one test causes cascading test failures; it's likely originating from triton.runtime.driver if I were to venture a guess.

One trick that might help in the meantime, just for this PR: try rebasing onto main to see if the issue repros, and if it doesn't, change the GitHub Actions regression-test workflow to run only your tests and see if any issues still pop up.

I also left a few misc pieces of feedback.

Finally, this PR would benefit from a simple README explaining where people should plug in their changes to run performance benchmarks. For example, this script, while significantly simpler, did help people run evals without much headache by just adding yet another if condition: https://github.com/pytorch/ao/blob/main/scripts/hf_eval.py

dtype: Optional[torch.dtype] = None
flops_by_dtype: dict = field(default_factory=dict)

def _post_init_check(self):
Member
This example was failing on these asserts for me. One set of failures was because the code was based off an old branch, and another was because it seems like device_spec has non-optional parameters such as dtype and flops_by_dtype:

from torchao.profiler import CUDADeviceSpec, TransformerPerformanceCounter
import torch

# Device spec is a dataclass that contains device info such as name, peak bandwidth, peak FLOPs
# If these fields are not manually specified, they will be automatically populated using
# CUDA runtime functions exposed by `torch.cuda` and `triton.runtime.driver`  
device_spec = CUDADeviceSpec()

# The manager object tracks latency, data movement (bytes), and FLOPs across multiple contexts and
# maintains performance metrics for each context and in aggregate.
manager = TransformerPerformanceCounter(device_spec=device_spec)

# Prefill
with manager.count(label="prefill", num_tokens=x.numel()):
    out = model(encoded_prompt)
# Print recorded stats for "prefill" context
manager.print_summary(labels=["prefill"], show=True) 

# Decode
with manager.count(label="decode", num_tokens=1):
    out = model(out[-1])
# Print recorded stats for "decode" context
manager.print_summary(labels=["decode"], show=True) 

# Print accumulated stats across all contexts
manager.print_summary(show=True)

)

# -------------------- Device Spec Tests ------------------- #
DEVICE_NAMES = ["h100 sxm", "a100", "nvidia geforce rtx 4090"]
@msaroufim (Member) Jul 27, 2024
We only have A10G GPUs in CI, which might explain some of the CI failures.

And we're exploring L4 instances next since those are cheaper and have fp8 support

@@ -0,0 +1,442 @@
"""
@msaroufim (Member) Jul 27, 2024

This script is still quite similar to https://github.com/pytorch/ao/blob/main/torchao/_models/llama/generate.py and I was hoping we could converge the two, or at the very least, if we are creating a separate profile script, have it call as many functions from generate as possible.
