
Comparing changes

base repository: meta-llama/llama3
base: main
head repository: youngkent/llama3_benchmark
compare: main

Able to merge. These branches can be automatically merged.
  • 2 commits
  • 3 files changed
  • 1 contributor

Commits on Aug 27, 2024

  1. Add a simple perf test

    Summary:
    Add a simple perf test.
    Use --enable_torch_compile to turn on torch.compile.
    Use --test_iterations to control how many batches to run.
    
    Test Plan:
    PYTHONPATH="/home/ktong/llama3_benchmark:$PYTHONPATH" torchrun --nproc_per_node=1 -m simple_llama3_perf_test --ckpt_dir <model path> --tokenizer_path <tokenizer path> --max_batch_size 32
    
    youngkent committed Aug 27, 2024
    Commit 817b780
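
    For reference, a run that exercises both new flags might look like the command below; the model and tokenizer paths are placeholders, and the flag names simply mirror the test_iterations and enable_torch_compile arguments of main() as exposed by fire:

    PYTHONPATH="/home/ktong/llama3_benchmark:$PYTHONPATH" torchrun --nproc_per_node=1 -m simple_llama3_perf_test --ckpt_dir <model path> --tokenizer_path <tokenizer path> --max_batch_size 32 --test_iterations 20 --enable_torch_compile True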

Commits on Aug 29, 2024

  1. Upgrade to llama3.1 model

    youngkent committed Aug 29, 2024
    Commit 585de1c
Showing with 135 additions and 5 deletions.
  1. +7 −2 llama/generation.py
  2. +40 −3 llama/model.py
  3. +88 −0 simple_llama3_perf_test.py
9 changes: 7 additions & 2 deletions llama/generation.py
@@ -41,6 +41,7 @@ def build(
         max_batch_size: int,
         model_parallel_size: Optional[int] = None,
         seed: int = 1,
+        enable_torch_compile: bool = False,
     ) -> "Llama":
         """
         Build a Llama instance by initializing and loading a model checkpoint.
@@ -109,8 +110,12 @@ def build(
         model = Transformer(model_args)
         model.load_state_dict(checkpoint, strict=False)
         print(f"Loaded in {time.time() - start_time:.2f} seconds")
-
-        return Llama(model, tokenizer)
+        if enable_torch_compile:
+            print(f"Torch compiling model ...")
+            compiled_model = torch.compile(model, mode="max-autotune", fullgraph=True)
+            return Llama(compiled_model, tokenizer)
+        else:
+            return Llama(model, tokenizer)
 
     def __init__(self, model: Transformer, tokenizer: Tokenizer):
         self.model = model
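
Outside the diff, a minimal sketch of opting into the compiled path through the new argument might look like the snippet below; the keyword arguments mirror the Llama.build call in simple_llama3_perf_test.py, and the checkpoint and tokenizer paths are placeholders:

    from llama import Llama

    # Placeholder paths; point these at a real Llama 3 checkpoint directory and tokenizer.
    generator = Llama.build(
        ckpt_dir="/path/to/checkpoints/Meta-Llama-3-8B/",
        tokenizer_path="/path/to/checkpoints/Meta-Llama-3-8B/tokenizer.model",
        max_seq_len=1024,
        max_batch_size=32,
        enable_torch_compile=True,  # wraps the Transformer with torch.compile(mode="max-autotune", fullgraph=True)
    )

With mode="max-autotune", the first generations pay the compilation and autotuning cost, which is why the benchmark script below runs a few warm-up batches before measuring.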
43 changes: 40 additions & 3 deletions llama/model.py
@@ -1,8 +1,14 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# top-level folder for each specific model found within the models/ directory at
+# the top-level of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
 
 import math
-from dataclasses import dataclass
 from typing import Optional, Tuple
 
 import fairscale.nn.model_parallel.initialize as fs_init
@@ -14,7 +20,7 @@
     VocabParallelEmbedding,
 )
 from torch import nn
-
+from dataclasses import dataclass
 
 @dataclass
 class ModelArgs:
@@ -27,6 +33,7 @@ class ModelArgs:
     ffn_dim_multiplier: Optional[float] = None
     norm_eps: float = 1e-5
     rope_theta: float = 500000
+    use_scaled_rope: bool = False
 
     max_batch_size: int = 32
     max_seq_len: int = 2048
@@ -46,9 +53,38 @@ def forward(self, x):
         return output * self.weight
 
 
-def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+def apply_scaling(freqs: torch.Tensor):
+    # Values obtained from grid search
+    scale_factor = 8
+    low_freq_factor = 1
+    high_freq_factor = 4
+    old_context_len = 8192  # original llama3 length
+
+    low_freq_wavelen = old_context_len / low_freq_factor
+    high_freq_wavelen = old_context_len / high_freq_factor
+    new_freqs = []
+    for freq in freqs:
+        wavelen = 2 * math.pi / freq
+        if wavelen < high_freq_wavelen:
+            new_freqs.append(freq)
+        elif wavelen > low_freq_wavelen:
+            new_freqs.append(freq / scale_factor)
+        else:
+            assert low_freq_wavelen != high_freq_wavelen
+            smooth = (old_context_len / wavelen - low_freq_factor) / (
+                high_freq_factor - low_freq_factor
+            )
+            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
+    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
+
+
+def precompute_freqs_cis(
+    dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False
+):
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
     t = torch.arange(end, device=freqs.device, dtype=torch.float32)
+    if use_scaled:
+        freqs = apply_scaling(freqs)
     freqs = torch.outer(t, freqs)
     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
     return freqs_cis
@@ -272,6 +308,7 @@ def __init__(self, params: ModelArgs):
             params.dim // params.n_heads,
             params.max_seq_len * 2,
             params.rope_theta,
+            params.use_scaled_rope,
         )
 
     @torch.inference_mode()
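
To make the new RoPE scaling concrete, here is a standalone sketch (not part of the diff) that reproduces the wavelength-based rule above on the actual Llama 3 frequency band; it assumes only torch and the constants shown in apply_scaling. The use_scaled_rope flag itself is threaded from ModelArgs into precompute_freqs_cis as shown in the last hunk, and since build() populates ModelArgs from the checkpoint's params.json, 3.1 checkpoints that set the flag there enable scaling without code changes.

    import math
    import torch

    def apply_scaling_demo(freqs: torch.Tensor) -> torch.Tensor:
        # Same constants as apply_scaling in the diff above.
        scale_factor = 8
        low_freq_factor = 1
        high_freq_factor = 4
        old_context_len = 8192
        low_freq_wavelen = old_context_len / low_freq_factor    # 8192
        high_freq_wavelen = old_context_len / high_freq_factor  # 2048
        out = []
        for freq in freqs.tolist():
            wavelen = 2 * math.pi / freq
            if wavelen < high_freq_wavelen:      # short wavelengths: keep the frequency
                out.append(freq)
            elif wavelen > low_freq_wavelen:     # long wavelengths: divide the frequency by 8
                out.append(freq / scale_factor)
            else:                                # in between: interpolate smoothly between the two
                smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
                out.append((1 - smooth) * freq / scale_factor + smooth * freq)
        return torch.tensor(out, dtype=freqs.dtype)

    # Llama 3 uses rope_theta=500000 and head_dim=128 for the 8B model.
    freqs = 1.0 / (500000.0 ** (torch.arange(0, 128, 2).float() / 128))
    scaled = apply_scaling_demo(freqs)
    print(scaled[:2], freqs[:2])    # highest frequencies are left unchanged
    print(scaled[-2:], freqs[-2:])  # lowest frequencies end up ~8x smaller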
88 changes: 88 additions & 0 deletions simple_llama3_perf_test.py
@@ -0,0 +1,88 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

import time
from typing import List

import fire

from llama import Llama


def run_once(
    generator,
    prompts,
    max_gen_len,
    temperature,
    top_p,
    print_output,
):
    st = time.time()
    results = generator.text_completion(
        prompts,
        max_gen_len=max_gen_len,
        temperature=temperature,
        top_p=top_p,
    )
    if print_output:
        for prompt, result in zip(prompts, results):
            print(prompt)
            print(f"> {result['generation']}")
            print("\n==================================\n")
    return time.time() - st


def main(
    ckpt_dir: str,
    tokenizer_path: str,
    temperature: float = 0.6,
    top_p: float = 0.9,
    max_seq_len: int = 1024,
    max_gen_len: int = 64,
    max_batch_size: int = 32,
    print_output: bool = False,
    test_iterations: int = 10,
    enable_torch_compile: bool = False,
):

    generator = Llama.build(
        ckpt_dir=ckpt_dir,
        tokenizer_path=tokenizer_path,
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
        enable_torch_compile=enable_torch_compile,
    )

    prompts: List[str] = ["I believe the meaning of life is"] * max_batch_size

    print(f"Warming up the model ...")
    for _ in range(3):
        run_once(generator, prompts, max_gen_len, temperature, top_p, print_output)

    print(f"Measuring perf ...")
    latencies = []
    throughputs = []
    for i in range(test_iterations):
        latency = run_once(
            generator,
            prompts,
            max_gen_len,
            temperature,
            top_p,
            print_output,
        )
        latencies.append(latency)
        throughput = len(prompts) / latency
        throughputs.append(throughput)

        print(
            f"Batch completed with total latency: {latency:.3f}s, QPS: {throughput:.3f}"
        )

    print(
        f"Average latency: {sum(latencies) / len(latencies):.3f}s, average QPS: {sum(throughputs) / len(throughputs):.3f}"
    )


if __name__ == "__main__":
    fire.Fire(main)
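
As a quick guide to reading the numbers this script prints: each iteration submits max_batch_size identical prompts, so the reported QPS is completed prompts per second, and since every completion is capped at max_gen_len tokens it also gives an approximate upper bound on generated tokens per second. A tiny illustration with made-up numbers:

    # Made-up numbers, purely to illustrate the arithmetic the script uses.
    max_batch_size, max_gen_len = 32, 64
    latency = 4.0                      # seconds for one batch, as returned by run_once()
    qps = max_batch_size / latency     # 8.0 -> printed as "QPS"
    tokens_per_s = qps * max_gen_len   # at most ~512 generated tokens/second
    print(f"QPS: {qps:.3f}, approx tokens/s upper bound: {tokens_per_s:.1f}")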