From 6b90ac7f3d5f720a785a97f7f30b161912a40438 Mon Sep 17 00:00:00 2001 From: Archana Ramalingam <98564406+archana-ramalingam@users.noreply.github.com> Date: Wed, 16 Oct 2024 10:30:46 -0700 Subject: [PATCH 1/3] Revert "[sharktank] Evaluation - Add Perplexity test" (#285) Reverts nod-ai/SHARK-Platform#233 --- .github/workflows/eval_test.yaml | 60 --- docs/model_cookbook.md | 10 - requirements.txt | 1 - .../sharktank/evaluate/data/eval_prompts.txt | 12 - sharktank/sharktank/evaluate/perplexity.py | 326 ----------------- .../sharktank/evaluate/perplexity_prefill.py | 276 -------------- sharktank/sharktank/utils/load_llm.py | 212 ----------- sharktank/sharktank/utils/tokenizer.py | 50 +-- sharktank/tests/evaluate/perplexity_test.py | 345 ------------------ 9 files changed, 13 insertions(+), 1279 deletions(-) delete mode 100644 .github/workflows/eval_test.yaml delete mode 100644 sharktank/sharktank/evaluate/data/eval_prompts.txt delete mode 100644 sharktank/sharktank/evaluate/perplexity.py delete mode 100644 sharktank/sharktank/evaluate/perplexity_prefill.py delete mode 100644 sharktank/sharktank/utils/load_llm.py delete mode 100644 sharktank/tests/evaluate/perplexity_test.py diff --git a/.github/workflows/eval_test.yaml b/.github/workflows/eval_test.yaml deleted file mode 100644 index 9410aeb36..000000000 --- a/.github/workflows/eval_test.yaml +++ /dev/null @@ -1,60 +0,0 @@ -name: Evaluation Tests - -on: - workflow_dispatch: - schedule: - # Weekdays nightly at 07:00 UTC = 23:00 PST / 00:00 PDT. - - cron: "0 7 * * 1-5" - -concurrency: - # A PR number if a pull request and otherwise the commit hash. This cancels - # queued and in-progress runs for the same PR (presubmit) or commit - # (postsubmit). The workflow name is prepended to avoid conflicts between - # different workflows. - group: ${{ github.workflow }}-${{ github.event.number || github.sha }} - cancel-in-progress: true - -jobs: - test_perplexity: - name: "Evaluation Tests - perplexity" - strategy: - matrix: - version: [3.11] - os: [ubuntu-latest, windows-latest] - fail-fast: false - runs-on: ${{matrix.os}} - defaults: - run: - shell: bash - env: - PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" - steps: - - name: "Setting up Python" - id: setup_python - uses: actions/setup-python@v3 - with: - python-version: ${{matrix.version}} - - - name: "Checkout Code" - uses: actions/checkout@v3 - - - name: Cache Pip Packages - uses: actions/cache@v4 - id: cache-pip - with: - path: ${{ env.PIP_CACHE_DIR }} - key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }} - - - name: Install pip deps - run: | - python -m pip install --no-compile --upgrade pip - # Note: We install in three steps in order to satisfy requirements - # from non default locations first. Installing the PyTorch CPU - # wheels saves multiple minutes and a lot of bandwidth on runner setup. 
- pip install --no-compile -r pytorch-cpu-requirements.txt - pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \ - -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" - pip install --no-compile -r requirements.txt -e sharktank/ shortfin/ - - - name: Run perplexity test - run: pytest sharktank/tests/evaluate/perplexity_test.py diff --git a/docs/model_cookbook.md b/docs/model_cookbook.md index 1190a723f..becf40820 100644 --- a/docs/model_cookbook.md +++ b/docs/model_cookbook.md @@ -257,16 +257,6 @@ iree-run-module \ --parameters=model=/tmp/open_llama_3b_v2/open-llama-3b-v2-f16.gguf ``` -## Evaluation pipeline - -Run perplexity test: - -```bash -python -m sharktank.evaluate.perplexity \ - --gguf-file=llama8b_f16.gguf \ - --tokenizer-config-json=tokenizer_config.json -``` - ## Generating data for llama models ```bash diff --git a/requirements.txt b/requirements.txt index a54088849..0198314f8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,6 @@ onnx==1.15.0 huggingface-hub==0.22.2 transformers==4.40.0 sentencepiece==0.2.0 -datasets==3.0.0 # It is expected that you have installed a PyTorch version/variant specific # to your needs, so we only include a minimum version spec. diff --git a/sharktank/sharktank/evaluate/data/eval_prompts.txt b/sharktank/sharktank/evaluate/data/eval_prompts.txt deleted file mode 100644 index fa3337c80..000000000 --- a/sharktank/sharktank/evaluate/data/eval_prompts.txt +++ /dev/null @@ -1,12 +0,0 @@ -Robert Boulter is an English film, television and theatre actor. -Robert Boulter had a guest-starring role on the television series "The Bill" in 2000. -Du Fu was a prominent Chinese poet of the Tang dynasty. -Along with Li Bai (Li Po), Du Fu is frequently called the greatest of the Chinese poets. -The Ise-class battleships were a pair of dreadnought battleships built for the Imperial Japanese Navy (IJN) during World War I. -Originally intended to be repeats of the preceding Fusō class, the Ise-class battleships were redesigned before construction began. Both ships carried supplies for the survivors of the Great Kantō earthquake in 1923. -They were modernized in 1934-37 with improvements to their armour and machinery and a rebuilt superstructure in the pagoda mast style. Afterwards they played a minor role in the Second Sino-Japanese War. -Richard Gale "Dick" Rifenburg (August 21, 1926-December 5, 1994) was an American football player and a pioneering television broadcaster for the forerunner to WIVB-TV in Buffalo. -Rifenburg played college football for the University of Michigan Wolverines in 1944 and from 1946 to 1948. He was a consensus selection at end on the 1948 College Football All-America Team. -Rifenburg played professionally in the National Football League (NFL) with the Detroit Lions for one season in 1950. After retiring from football he settled in Buffalo and became a sports broadcaster. -An oxaziridine is an organic molecule that features a three-membered heterocycle containing oxygen, nitrogen, and carbon. In their largest application, oxazidines are intermediates in the industrial production of hydrazine. -Oxaziridine derivatives are also used as specialized reagents in organic chemistry for a variety of oxidations, including alpha hydroxylation of enolates, epoxidation and aziridination of olefins, and other heteroatom transfer reactions. 
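Note on the code being reverted: the `sharktank.evaluate.perplexity` module deleted below computes perplexity as the exponentiated mean negative log-likelihood over non-padded tokens, i.e. `e^(sum(losses) / num_tokenized_tokens)`. For reference, a minimal standalone sketch of that reduction is shown here; it assumes plain tensors and skips the sharktank generator/tokenizer plumbing, and `perplexity_from_logits` is an illustrative name rather than an API of the deleted module.

```python
import torch
from torch.nn import CrossEntropyLoss


def perplexity_from_logits(
    logits: torch.Tensor,          # [bs, seq_len, vocab] next-token logits
    token_ids: torch.Tensor,       # [bs, seq_len] reference token ids
    attention_mask: torch.Tensor,  # [bs, seq_len] 1 for real tokens, 0 for padding
) -> torch.Tensor:
    # Shift so the logits at position i are scored against the token at position i + 1.
    logits = logits[..., :-1, :].contiguous()
    targets = token_ids[..., 1:].contiguous()
    mask = attention_mask[..., 1:].contiguous()

    # Per-token negative log-likelihood; CrossEntropyLoss expects [bs, vocab, seq_len].
    nll = CrossEntropyLoss(reduction="none")(logits.transpose(1, 2), targets)
    nll = nll * mask

    # Perplexity per sequence: exp(sum of losses / number of scored tokens).
    return torch.exp(nll.sum(1) / mask.sum(1))
```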
diff --git a/sharktank/sharktank/evaluate/perplexity.py b/sharktank/sharktank/evaluate/perplexity.py deleted file mode 100644 index be794788b..000000000 --- a/sharktank/sharktank/evaluate/perplexity.py +++ /dev/null @@ -1,326 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import sys -import logging -import time -from datetime import timedelta -import json -import numpy as np -from tqdm import tqdm - -from datasets import load_dataset - -import torch -from torch.nn import CrossEntropyLoss - -from sharktank.layers import * -from sharktank.types import * - -from sharktank.models.llama.llama import * -from sharktank.models.mixtral.mixtral import * -from sharktank.models.grok.grok import * - -from ..models.llama.sharding import shard_theta - -from sharktank.utils import cli -from sharktank.utils.load_llm import * - -log_levels = { - "info": logging.INFO, - "debug": logging.DEBUG, -} -logger = logging.getLogger("eval") - -logger.setLevel(log_levels["info"]) - -logger.root.handlers[0].setFormatter( - logging.Formatter(fmt="\n%(levelname)s:%(name)-8s %(message)s") -) - -__all__ = ["Perplexity", "run_perplexity"] - - -class Perplexity: - """ - Perplexity (PPL) is one of the most common metrics for evaluating language models. - It is defined as the exponentiated average negative log-likelihood of a sequence, - calculated with exponent base `e`. - - For more information, see https://huggingface.co/docs/transformers/perplexity - """ - - def __init__( - self, - prompts: list, - device, - kv_cache_type, - ): - self.prompts = prompts - self.add_start_token = False - self.batch_size = 16 - self.bs = len(prompts) - self.device = device - self.kv_cache_type = kv_cache_type - self.activation_dtype = torch.float32 - self.attention_dtype = torch.float32 - - def timeit(func): - def wrapper(*args, **kwargs): - start = time.time() - result = func(*args, **kwargs) - end = time.time() - seconds = end - start - time_taken = abs(timedelta(seconds=round(seconds))) - - if seconds < 1: - time_taken = f" {seconds * 1000} ms" - - func_name = func.__name__ - if func_name == "get_perplexity": - func_name = "Total time" - logger.info(f" {func_name}: {time_taken}") - return result - - return wrapper - - def print_token_comparison(self, i): - if i <= self.max_prompt_length: - batch_predicted_token_id = [[i[-1]] for i in self.batch.results] - batch_predicted_token = self.generator.tokenizer.decode( - batch_predicted_token_id - ) - logger.debug(f"Predicted:") - logger.debug(f"{batch_predicted_token}") - logger.debug(f"{batch_predicted_token_id}") - - expected_token_id = self.token_ids[:, i + 1 : i + 2].tolist() - expected_token = self.generator.tokenizer.decode(expected_token_id) - logger.debug(f"Expected:") - logger.debug(f"{expected_token}") - logger.debug(f"{expected_token_id}") - - @timeit - def load_model(self, dataset, tokenizer, tensor_parallelism_size, attention_kernel): - - config = LlamaModelConfig( - hp=configs.LlamaHParams.from_gguf_props(dataset.properties), - block_seq_stride=16, - kv_cache_type=self.kv_cache_type, - device=self.device, - activation_dtype=self.activation_dtype, - attention_dtype=self.attention_dtype, - ) - - if tensor_parallelism_size > 1: - dataset.root_theta = shard_theta(dataset.root_theta, config) - - theta = dataset.root_theta - - if config.hp.expert_count: - if config.hp.model_arch == "grok": - 
model = PagedGrokModelV1(theta, config) - else: - model = PagedMixtralModelV1(theta, config) - else: - model = PagedLlamaModelV1(theta, config) - - self.generator = TorchGenerator(model, tokenizer) - - @timeit - def get_logits(self): - - token_ids, seq_lens = self.generator.tokenizer.encode( - self.prompts, - pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride, - add_start_token=self.add_start_token, - ) - - logger.info(f" Prompts for Evaluation:") - for idx, prompt in enumerate(self.prompts): - logger.info( - f" Prompt {idx}: \nTokens: {prompt.encode()}\nToken ids: {token_ids[idx]}\n" - ) - - self.max_prompt_length = max(seq_lens) - - self.token_ids = torch.tensor(token_ids, device=self.device) - self.attention_mask = ( - (self.token_ids != 0).int().detach().clone().to(self.device) - ) - - is_first_token = True - start = 0 - for i in tqdm( - range(start, self.max_prompt_length - 1), - desc="eval: Calculating logits", - ): - logger.debug(f"Iteration: {i}") - - if is_first_token: - - token_batch = self.token_ids[:, : i + 1] - logger.debug(f"Prefill:") - - logger.debug("Input:") - logger.debug(f"{self.generator.tokenizer.decode(token_batch)}") - - token_batch, seq_lens_batch = self.generator.tokenizer.pad_tokens( - token_ids=token_batch.tolist(), - pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride, - ) - - logger.debug(f"{token_batch}") - - token_batch = torch.tensor(token_batch, device=self.device) - seq_lens_batch = torch.tensor(seq_lens_batch, device=self.device) - - self.batch = self.generator.begin_eval_batch( - token_batch=token_batch, - seq_lens_batch=seq_lens_batch, - bs=self.bs, - ) - - self.batch.prefill() - self.out_logits = self.batch.prefill_logits[:, 0:1, :] - is_first_token = False - - self.print_token_comparison(i) - - else: - token_batch = self.token_ids[:, i : i + 1] - - logger.debug("Decode:") - - logger.debug("Input:") - logger.debug(f"{self.generator.tokenizer.decode(token_batch)}") - logger.debug(f"{token_batch.tolist()}") - - self.batch.decode(token_batch=token_batch) - self.out_logits = torch.cat( - (self.out_logits, self.batch.decode_logits), 1 - ) - - self.print_token_comparison(i) - - pad_logits_shape = self.token_ids.shape[1] - self.out_logits.shape[1] - - self.pad_logits = torch.zeros( - self.out_logits.shape[0], pad_logits_shape, self.out_logits.shape[2] - ) - - self.out_logits = torch.cat((self.out_logits, self.pad_logits), 1).to( - self.device - ) - - @timeit - def compute_perplexity(self): - loss_fct = CrossEntropyLoss(reduction="none") - - ## perplexity = e ^ (sum(losses) / num_tokenized_tokens) - crossentropy_loss = ( - loss_fct(self.out_logits.transpose(1, 2), self.token_ids) - * self.attention_mask - ).sum(1) - crossentropy_loss = torch.tensor(crossentropy_loss.tolist()) - perplexity_batch = torch.exp( - crossentropy_loss / self.attention_mask.sum(1) - ).tolist() - - return { - "perplexities": perplexity_batch, - "mean_perplexity": np.mean(perplexity_batch), - } - - @timeit - def get_perplexity(self): - - self.get_logits() - - self.out_logits = self.out_logits[..., :-1, :].contiguous() - self.token_ids = self.token_ids[..., 1:].contiguous() - self.attention_mask = self.attention_mask[..., 1:].contiguous() - - logger.debug(f"Final Logits shape: {self.out_logits.shape}") - logger.debug(f"Token ids: {self.token_ids}, \n{self.token_ids.shape}") - logger.debug( - f"Mask shape: {self.attention_mask}, \n{self.attention_mask.shape}" - ) - - assert self.token_ids.shape == self.out_logits.shape[0:2] - - return self.compute_perplexity() 
- - -def run_perplexity( - prompts: list[str], - dataset, - tokenizer, - device, - kv_cache_type, - tensor_parallelism_size, - attention_kernel, -): - perplexity = Perplexity(prompts=prompts, device=device, kv_cache_type=kv_cache_type) - - perplexity.load_model(dataset, tokenizer, tensor_parallelism_size, attention_kernel) - ppl = perplexity.get_perplexity() - - return ppl - - -def main(argv): - parser = cli.create_parser() - parser.add_argument("--kv-cache-type", default="paged", help="KV cache type") - parser.add_argument("--device", help="Torch device (or default)") - parser.add_argument( - "--attention-kernel", - type=str, - default="decomposed", - choices=["decomposed", "torch_sdpa"], - ) - - parser.add_argument( - "--tensor-parallelism-size", - type=int, - default=1, - help="Number of devices for tensor parallel sharding.", - ) - - cli.add_input_dataset_options(parser) - cli.add_tokenizer_options(parser) - args = cli.parse(parser, args=argv) - - device = torch.device(args.device) if args.device else None - kv_cache_type = args.kv_cache_type - dataset = cli.get_input_dataset(args) - tokenizer = cli.get_tokenizer(args) - - input_texts = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"][ - :20 - ] - - # Ignore prompts that are: empty, less than 5 words or a title. - input_texts = [ - s for s in input_texts if s != "" and len(s.split()) > 5 and s.count("=") < 2 - ] - - ppl = run_perplexity( - prompts=input_texts, - dataset=dataset, - tokenizer=tokenizer, - device=device, - kv_cache_type=kv_cache_type, - tensor_parallelism_size=args.tensor_parallelism_size, - attention_kernel=args.attention_kernel, - ) - - logger.info(f"\n{json.dumps(ppl, indent=2)}") - return ppl - - -if __name__ == "__main__": - main(sys.argv[1:]) diff --git a/sharktank/sharktank/evaluate/perplexity_prefill.py b/sharktank/sharktank/evaluate/perplexity_prefill.py deleted file mode 100644 index 2bb785801..000000000 --- a/sharktank/sharktank/evaluate/perplexity_prefill.py +++ /dev/null @@ -1,276 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import sys -import logging -import time -from datetime import timedelta - -import json -import numpy as np -from tqdm import tqdm - -import torch -from torch.nn import CrossEntropyLoss - -from sharktank.layers import * -from sharktank.types import * - -from sharktank.models.llama.llama import * -from sharktank.models.mixtral.mixtral import * -from sharktank.models.grok.grok import * - -from sharktank.utils import cli -from sharktank.utils.load_llm import * - -log_levels = { - "info": logging.INFO, - "debug": logging.DEBUG, -} -logger = logging.getLogger("eval") - -logger.setLevel(log_levels["debug"]) - -logger.root.handlers[0].setFormatter( - logging.Formatter(fmt="\n%(levelname)s:%(name)-8s %(message)s") -) - -__all__ = ["Perplexity", "run_perplexity"] - - -class Perplexity: - """ - Perplexity (PPL) is one of the most common metrics for evaluating language models. - It is defined as the exponentiated average negative log-likelihood of a sequence, - calculated with exponent base `e`. 
- - For more information, see https://huggingface.co/docs/transformers/perplexity - """ - - def __init__( - self, - prompts: list, - device, - kv_cache_type, - ): - self.prompts = prompts - self.add_start_token = False - self.batch_size = 16 - self.bs = len(prompts) - self.device = device - self.kv_cache_type = kv_cache_type - - def timeit(func): - def wrapper(*args, **kwargs): - start = time.time() - result = func(*args, **kwargs) - end = time.time() - seconds = end - start - time_taken = abs(timedelta(seconds=round(seconds))) - - if seconds < 1: - time_taken = f" {seconds * 1000} ms" - - func_name = func.__name__ - if func_name == "get_perplexity": - func_name = "Total time" - logger.info(f" {func_name}: {time_taken}") - return result - - return wrapper - - def print_token_comparison(self, i): - if i <= self.max_prompt_length: - batch_predicted_token_id = [[i[-1]] for i in self.batch.results] - batch_predicted_token = self.generator.tokenizer.decode( - batch_predicted_token_id - ) - logger.debug(f"Predicted:") - logger.debug(f"{batch_predicted_token}") - logger.debug(f"{batch_predicted_token_id}") - - expected_token_id = self.token_ids[:, i + 1 : i + 2].tolist() - expected_token = self.generator.tokenizer.decode(expected_token_id) - logger.debug(f"Expected:") - logger.debug(f"{expected_token}") - logger.debug(f"{expected_token_id}") - - @timeit - def load_model(self, dataset, tokenizer): - - theta = dataset.root_theta - - config = LlamaModelConfig( - hp=configs.LlamaHParams.from_gguf_props(dataset.properties), - block_seq_stride=16, - kv_cache_type=self.kv_cache_type, - device=self.device, - activation_dtype=torch.float32, - attention_dtype=torch.float32, - ) - - if config.hp.expert_count: - if config.hp.model_arch == "grok": - model = PagedGrokModelV1(theta, config) - else: - model = PagedMixtralModelV1(theta, config) - else: - model = PagedLlamaModelV1(theta, config) - - self.generator = TorchGenerator(model, tokenizer) - - @timeit - def get_logits(self): - - token_ids, seq_lens = self.generator.tokenizer.encode( - self.prompts, - pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride, - add_start_token=self.add_start_token, - ) - - logger.info(f" Prompts:") - for idx, prompt in enumerate(self.prompts): - logger.info(f" Prompt {idx} - {prompt.encode()}\n{token_ids[idx]}") - - self.max_prompt_length = max(seq_lens) - - self.token_ids = torch.tensor(token_ids, device=self.device) - self.attention_mask = ( - (self.token_ids != 0).int().detach().clone().to(self.device) - ) - - is_first_token = True - for i in tqdm( - range(0, self.max_prompt_length - 1), - desc="eval: Calculating logits", - ): - token_batch = self.token_ids[:, : i + 1] - logger.debug(f"Prefill:") - - logger.debug("Input:") - logger.debug(f"{self.generator.tokenizer.decode(token_batch)}") - - token_batch, seq_lens_batch = self.generator.tokenizer.pad_tokens( - token_ids=token_batch.tolist(), - pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride, - ) - - token_batch = torch.tensor(token_batch, device=self.device) - seq_lens_batch = torch.tensor(seq_lens_batch, device=self.device) - - self.batch = self.generator.begin_eval_batch( - token_batch=token_batch, - seq_lens_batch=seq_lens_batch, - bs=self.bs, - ) - - self.cache_state = self.batch.prefill() - self.print_token_comparison(i) - - if is_first_token: - self.out_logits = self.batch.prefill_logits[:, 0:1, :] - is_first_token = False - else: - self.out_logits = torch.cat( - (self.out_logits, self.batch.prefill_logits[:, 0:1, :]), 1 - ) - - 
pad_logits_shape = self.token_ids.shape[1] - self.out_logits.shape[1] - - self.pad_logits = torch.zeros( - self.out_logits.shape[0], pad_logits_shape, self.out_logits.shape[2] - ) - - self.out_logits = torch.cat((self.out_logits, self.pad_logits), 1).to( - self.device - ) - - @timeit - def compute_perplexity(self): - loss_fct = CrossEntropyLoss(reduction="none") - - ## perplexity = e ^ (sum(losses) / num_tokenized_tokens) - crossentropy_loss = ( - loss_fct(self.out_logits.transpose(1, 2), self.token_ids) - * self.attention_mask - ).sum(1) - crossentropy_loss = torch.tensor(crossentropy_loss.tolist()) - perplexity_batch = torch.exp( - crossentropy_loss / self.attention_mask.sum(1) - ).tolist() - - return { - "perplexities": perplexity_batch, - "mean_perplexity": np.mean(perplexity_batch), - } - - @timeit - def get_perplexity(self): - - self.get_logits() - - self.out_logits = self.out_logits[..., :-1, :].contiguous() - self.token_ids = self.token_ids[..., 1:].contiguous() - self.attention_mask = self.attention_mask[..., 1:].contiguous() - - assert self.token_ids.shape == self.out_logits.shape[0:2] - - logger.debug(f"Logits shape: {self.out_logits.shape}") - logger.debug(f"Token ids: {self.token_ids}, {self.token_ids.shape}") - logger.debug( - f"Logits shape: {self.attention_mask}, {self.attention_mask.shape}" - ) - - return self.compute_perplexity() - - -def run_perplexity( - prompts: list[str], - dataset, - tokenizer, - device, - kv_cache_type, -): - perplexity = Perplexity(prompts=prompts, device=device, kv_cache_type=kv_cache_type) - - perplexity.load_model(dataset, tokenizer) - ppl = perplexity.get_perplexity() - - return ppl - - -def main(argv): - parser = cli.create_parser() - parser.add_argument("--kv-cache-type", default="paged", help="KV cache type") - parser.add_argument("--device", help="Torch device (or default)") - - cli.add_input_dataset_options(parser) - cli.add_tokenizer_options(parser) - args = cli.parse(parser, args=argv) - - device = torch.device(args.device) if args.device else None - kv_cache_type = args.kv_cache_type - dataset = cli.get_input_dataset(args) - tokenizer = cli.get_tokenizer(args) - - prompt_path = "sharktank/evaluate/data/eval_prompts.txt" - with open(prompt_path, "r") as f: - input_texts = f.read().splitlines() - - ppl = run_perplexity( - prompts=input_texts[0:1], - dataset=dataset, - tokenizer=tokenizer, - device=device, - kv_cache_type=kv_cache_type, - ) - - logger.info(f"\n{json.dumps(ppl, indent=2)}") - return ppl - - -if __name__ == "__main__": - main(sys.argv[1:]) diff --git a/sharktank/sharktank/utils/load_llm.py b/sharktank/sharktank/utils/load_llm.py deleted file mode 100644 index 558653d9b..000000000 --- a/sharktank/sharktank/utils/load_llm.py +++ /dev/null @@ -1,212 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import math - -import torch - -from sharktank.layers import * -from sharktank.types import * -from sharktank.models.llama.llama import * - -from ..utils.debugging import trace_tensor -from ..utils.tokenizer import InferenceTokenizer - - -class TorchGenerator: - """Generator that runs directly on the Torch model.""" - - def __init__( - self, - model: PagedLlamaModelV1, - tokenizer: InferenceTokenizer, - page_cache_size: int = 8192, - # Need to look at the model more for this. 
- end_token: int = 2, - ): - self.model = model - self.tokenizer = tokenizer - if model.cache.is_paged: - self.shared_cache_state = model.cache.paged.allocate(page_cache_size) - else: - self.shared_cache_state = None - self.free_pages = list(range(1, 8192)) - self.end_token = end_token - - @property - def block_seq_stride(self) -> int: - return self.model.cache.block_seq_stride - - def begin_batch(self, prompts: list[str], add_start_token: bool): - token_ids, seq_lens = self.tokenizer.encode( - prompts, - pad_to_multiple_of=self.model.cache.pad_sequence_stride, - add_start_token=add_start_token, - ) - token_ids = torch.tensor(token_ids, device=self.model.device) - seq_lens = torch.tensor(seq_lens, device=self.model.device) - if self.shared_cache_state is not None: - cache_state = self.shared_cache_state - else: - cache_state = self.model.cache.direct.allocate(bs=len(prompts)) - return Batch(self, token_ids, seq_lens, cache_state) - - def begin_eval_batch( - self, - token_batch: torch.tensor, - seq_lens_batch: torch.tensor, - bs: int, - ): - - if self.shared_cache_state is not None: - cache_state = self.shared_cache_state - else: - cache_state = self.model.cache.direct.allocate(bs=bs) - return Batch(self, token_batch, seq_lens_batch, cache_state) - - def alloc_page(self) -> int: - if self.model.cache.is_direct: - # We don't allocate block ids for the direct cache. - return 0 - - return self.free_pages.pop() - - def release_page(self, index: int): - if self.model.cache.is_direct: - return - self.free_pages.append(index) - - -class Batch: - def __init__( - self, - parent: TorchGenerator, - token_ids: torch.Tensor, - seq_lens: torch.Tensor, - cache_state: list[torch.Tensor], - ): - self.bs = token_ids.shape[0] - # assert seq_lens.shape[0] == self.bs - self.parent = parent - self.token_ids = token_ids - self.seq_lens = seq_lens - self.cache_state = cache_state - self.results: list[list[int]] = [[] for _ in range(self.bs)] - self.done_result_indices: set[int] = set() - - # Assemble the batch. 
- seq_stride = self.parent.block_seq_stride - self.seq_block_ids: list[list[int]] = [] - for seq_len in self.seq_lens: - blocks_needed = ( - int(math.ceil(seq_len / seq_stride)) if seq_stride > 0 else 0 - ) - row = [] - for _ in range(blocks_needed): - row.append(self.parent.alloc_page()) - self.seq_block_ids.append(row) - - @property - def done(self) -> bool: - return len(self.done_result_indices) == self.bs - - def detokenize(self) -> list[str]: - return self.parent.tokenizer.decode(self.results) - - def print_current_results(self): - results = self.detokenize() - for i, s in enumerate(results): - seq_len = int(self.seq_lens[i]) - print(f" {i}({len(self.results[i])}, {seq_len}): {s}") - - def add_result_token(self, tokens: torch.Tensor): - for i in range(self.bs): - token = tokens[i][0] - if token == self.parent.end_token: - self.done_result_indices.add(i) - if i in self.done_result_indices: - continue - token = int(tokens[i, 0]) - self.results[i].append(token) - - def allocate_seq_block_ids(self): - for i in range(self.bs): - sl = int(self.seq_lens[i]) - if (sl % self.parent.block_seq_stride) == 0: - needed_blocks = sl // self.parent.block_seq_stride + 1 - else: - needed_blocks = math.ceil(sl / self.parent.block_seq_stride) - block_ids_row = self.seq_block_ids[i] - while len(block_ids_row) < needed_blocks: - block_ids_row.append(self.parent.alloc_page()) - - def prefill(self): - model = self.parent.model - attention_mask = model.attention_mask( - model.input_mask(self.seq_lens, self.token_ids.shape[1]) - ) - seq_block_ids_tensor = self.pad_block_ids() - trace_tensor("prefill.token_ids", self.token_ids) - trace_tensor("prefill.seq_block_ids", seq_block_ids_tensor) - trace_tensor("prefill.attention_mask", attention_mask) - self.prefill_logits = model.prefill( - self.token_ids, - attention_mask=attention_mask, - seq_block_ids=seq_block_ids_tensor, - cache_state=self.cache_state, - ) - - # TODO: Generalize the sampling and don't make it swap on/off cpu. - # TODO: Normalize the output of extract_tokens_from_logits into - # tensor [bs, 1]. - tokens = torch.tensor( - model.extract_tokens_from_logits(self.prefill_logits, self.seq_lens) - ).unsqueeze(1) - self.add_result_token(tokens) - self.next_tokens = tokens.to(device=model.device) - - def decode(self, token_batch): - self.token_batch = token_batch - - model = self.parent.model - start_positions = self.seq_lens.clone() - self.seq_lens.add_(1) - self.allocate_seq_block_ids() - # TODO: Allocate more blocks on overflow. - seq_block_ids_tensor = self.pad_block_ids() - decode_attention_mask = model.decode_attention_mask( - model.input_mask( - self.seq_lens, - seq_block_ids_tensor.shape[1] * self.parent.block_seq_stride, - ) - ) - trace_tensor("decode.token_ids", self.token_ids) - trace_tensor("decode.start_positions", start_positions) - trace_tensor("decode.seq_block_ids", seq_block_ids_tensor) - trace_tensor("decode.attention_mask", decode_attention_mask) - - self.decode_logits = model.decode( - self.token_batch, - attention_mask=decode_attention_mask, - start_positions=start_positions, - seq_block_ids=seq_block_ids_tensor, - cache_state=self.cache_state, - ) - - trace_tensor("decode.logits", self.decode_logits) - # # TODO: Normalize the output of extract_tokens_from_logits into - # # tensor [bs, 1]. 
- tokens = torch.tensor( - model.extract_tokens_from_logits(self.decode_logits, [1] * self.bs), - device=self.parent.model.device, - ).unsqueeze(1) - self.add_result_token(tokens) - self.next_tokens = tokens - - def pad_block_ids(self) -> torch.Tensor: - max_length = max(len(r) for r in self.seq_block_ids) - rows = [r + (max_length - len(r)) * [0] for r in self.seq_block_ids] - return torch.tensor(rows, device=self.parent.model.device) diff --git a/sharktank/sharktank/utils/tokenizer.py b/sharktank/sharktank/utils/tokenizer.py index f272055c9..29a57f958 100644 --- a/sharktank/sharktank/utils/tokenizer.py +++ b/sharktank/sharktank/utils/tokenizer.py @@ -22,57 +22,34 @@ class InferenceTokenizer(ABC): """Simple inference tokenizer.""" def encode( - self, - texts: list[str], - pad_to_multiple_of: int = 1, - add_start_token: bool = True, + self, texts: list[str], pad_to_multiple_of: int = 1, pad_token: int = 0 ) -> tuple[list[list[int]]]: """Encodes a list of texts into a padded list of tokens. Returns a list of list of tokens and a list of unpadded lengths. """ - raw_rows = self._encode(texts, add_start_token) - raw_rows, lengths = self.pad_tokens( - token_ids=raw_rows, pad_to_multiple_of=pad_to_multiple_of - ) - return raw_rows, lengths - - def decode(self, tokens: Union[list[list[int]]], lens: Optional[list[int]] = None): - """Decodes a list of tokens.""" - if lens is not None: - tokens = list(tokens) - for i, row_length in enumerate(lens): - tokens[i] = tokens[i][0:row_length] - return self._decode(tokens) - - def get_prompt_lengths( - self, - token_ids: list[list[int]], - ): + raw_rows = self._encode(texts) max_length = 0 lengths: list[int] = [] - for row in token_ids: + for row in raw_rows: lengths.append(len(row)) max_length = max(max_length, len(row)) - - return lengths, max_length - - def pad_tokens( - self, - token_ids: list[list[int]], - pad_to_multiple_of: int, - pad_token: int = 0, - ): - lengths, max_length = self.get_prompt_lengths(token_ids) if pad_to_multiple_of > 1: max_length = int( pad_to_multiple_of * math.ceil(max_length / pad_to_multiple_of) ) - for row in token_ids: + for row in raw_rows: pad_count = max_length - len(row) row.extend(pad_count * [pad_token]) + return raw_rows, lengths - return token_ids, lengths + def decode(self, tokens: Union[list[list[int]]], lens: Optional[list[int]] = None): + """Decodes a list of tokens.""" + if lens is not None: + tokens = list(tokens) + for i, row_length in enumerate(lens): + tokens[i] = tokens[i][0:row_length] + return self._decode(tokens) @abstractmethod def _encode(self, texts: list[str]) -> list[list[int]]: @@ -99,10 +76,9 @@ class _TransformersTokenizer(InferenceTokenizer): def __init__(self, t: AutoTokenizer): self._t = t - def _encode(self, texts: list[str], add_start_token: bool) -> list[list[int]]: + def _encode(self, texts: list[str]) -> list[list[int]]: results = t.batch_encode_plus( texts, - add_special_tokens=add_start_token, padding=False, truncation=False, ) diff --git a/sharktank/tests/evaluate/perplexity_test.py b/sharktank/tests/evaluate/perplexity_test.py deleted file mode 100644 index ab2091f06..000000000 --- a/sharktank/tests/evaluate/perplexity_test.py +++ /dev/null @@ -1,345 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import unittest -import pytest - -from sharktank.evaluate import perplexity - - -class PerplexityTest(unittest.TestCase): - def test_llama3_8B_f16_decomposed(self): - - # Llama 3.1 8B decomposed - - llama_8b_f16_gguf_path = "/data/extra/models/llama3.1_8B/llama8b_f16.gguf" - llama_8b_f16_tokenizer_path = ( - "/data/extra/models/llama3.1_8B/tokenizer_config.json" - ) - - llama_8b_perplexity = perplexity.main( - [ - f"--gguf-file={llama_8b_f16_gguf_path}", - f"--tokenizer-config-json={llama_8b_f16_tokenizer_path}", - ] - ) - - baseline_llama_8b_perplexity = { - "perplexities": [ - 9.875290870666504, - 8.075149536132812, - 16.164775848388672, - 11.06580924987793, - 11.46964168548584, - 12.714613914489746, - ], - "mean_perplexity": 11.560880184173584, - } - - delta = 5e-1 - - self.assertAlmostEqual( - baseline_llama_8b_perplexity["mean_perplexity"], - llama_8b_perplexity["mean_perplexity"], - delta=delta, - msg=f"Perplexity is deviating more than {delta}", - ) - - @pytest.mark.xfail - def test_llama3_8B_f16_non_decomposed(self): - - # Llama 3.1 8B non-decomposed - - llama_8b_f16_gguf_path = "/data/extra/models/llama3.1_8B/llama8b_f16.gguf" - llama_8b_f16_tokenizer_path = ( - "/data/extra/models/llama3.1_8B/tokenizer_config.json" - ) - - llama_8b_perplexity = perplexity.main( - [ - f"--gguf-file={llama_8b_f16_gguf_path}", - f"--tokenizer-config-json={llama_8b_f16_tokenizer_path}", - f"--attention-kernel=torch_sdpa", - ] - ) - - # dummy data - baseline_llama_8b_perplexity = { - "perplexities": [ - 9.875290870666504, - 8.075149536132812, - 16.164775848388672, - 11.06580924987793, - 11.46964168548584, - 12.714613914489746, - ], - "mean_perplexity": 11.560880184173584, - } - - delta = 5e-1 - - self.assertAlmostEqual( - baseline_llama_8b_perplexity["mean_perplexity"], - llama_8b_perplexity["mean_perplexity"], - delta=delta, - msg=f"Perplexity is deviating more than {delta}", - ) - - @pytest.mark.xfail - def test_llama3_8B_fp8_decomposed(self): - - # Llama 3.1 8B decomposed - - llama_8b_fp8_gguf_path = "/data/extra/models/llama3.1_8B/llama8b_fp8.gguf" - llama_8b_fp8_tokenizer_path = ( - "/data/extra/models/llama3.1_8B/tokenizer_config.json" - ) - - llama_8b_perplexity = perplexity.main( - [ - f"--gguf-file={llama_8b_fp8_gguf_path}", - f"--tokenizer-config-json={llama_8b_fp8_tokenizer_path}", - ] - ) - - # dummy data - baseline_llama_8b_perplexity = { - "perplexities": [ - 9.875290870666504, - 8.075149536132812, - 16.164775848388672, - 11.06580924987793, - 11.46964168548584, - 12.714613914489746, - ], - "mean_perplexity": 11.560880184173584, - } - - delta = 5e-1 - - self.assertAlmostEqual( - baseline_llama_8b_perplexity["mean_perplexity"], - llama_8b_perplexity["mean_perplexity"], - delta=delta, - msg=f"Perplexity is deviating more than {delta}", - ) - - @pytest.mark.xfail - def test_llama3_8B_fp8_non_decomposed(self): - - # Llama 3.1 8B non-decomposed - - llama_8b_fp8_gguf_path = "/data/extra/models/llama3.1_8B/llama8b_fp8.gguf" - llama_8b_fp8_tokenizer_path = ( - "/data/extra/models/llama3.1_8B/tokenizer_config.json" - ) - - llama_8b_perplexity = perplexity.main( - [ - f"--gguf-file={llama_8b_fp8_gguf_path}", - f"--tokenizer-config-json={llama_8b_fp8_tokenizer_path}", - f"--attention-kernel=torch_sdpa", - ] - ) - - # dummy data - baseline_llama_8b_perplexity = { - "perplexities": [ - 9.875290870666504, - 8.075149536132812, - 16.164775848388672, - 11.06580924987793, - 11.46964168548584, - 12.714613914489746, - ], - "mean_perplexity": 
11.560880184173584, - } - - delta = 5e-1 - - self.assertAlmostEqual( - baseline_llama_8b_perplexity["mean_perplexity"], - llama_8b_perplexity["mean_perplexity"], - delta=delta, - msg=f"Perplexity is deviating more than {delta}", - ) - - def test_llama3_405B_f16_decomposed(self): - - # Llama 3.1 405B decomposed - - llama_405b_f16_gguf_path = ( - "/data/extra/models/llama3.1_405B/llama405b_fp16.gguf" - ) - llama_405b_f16_tokenizer_path = ( - "/data/extra/models/llama3.1_405B/tokenizer_config.json" - ) - - tensor_parallelism_size = 8 - - llama_405b_perplexity = perplexity.main( - [ - f"--gguf-file={llama_405b_f16_gguf_path}", - f"--tokenizer-config-json={llama_405b_f16_tokenizer_path}", - f"--tensor-parallelism-size={tensor_parallelism_size}", - ] - ) - - # dummy data - baseline_llama_405b_perplexity = { - "perplexities": [ - 9.875290870666504, - 8.075149536132812, - 16.164775848388672, - 11.06580924987793, - 11.46964168548584, - 12.714613914489746, - ], - "mean_perplexity": 11.560880184173584, - } - - delta = 5e-1 - - self.assertAlmostEqual( - baseline_llama_405b_perplexity["mean_perplexity"], - llama_405b_perplexity["mean_perplexity"], - delta=delta, - msg=f"Perplexity is deviating more than {delta}", - ) - - @pytest.mark.xfail - def test_llama3_405B_f16_non_decomposed(self): - - # Llama 3.1 405B non-decomposed - - llama_405b_f16_gguf_path = ( - "/data/extra/models/llama3.1_405B/llama405b_fp16.gguf" - ) - llama_405b_f16_tokenizer_path = ( - "/data/extra/models/llama3.1_405B/tokenizer_config.json" - ) - - tensor_parallelism_size = 8 - - llama_405b_perplexity = perplexity.main( - [ - f"--gguf-file={llama_405b_f16_gguf_path}", - f"--tokenizer-config-json={llama_405b_f16_tokenizer_path}", - f"--tensor-parallelism-size={tensor_parallelism_size}", - f"--attention-kernel=torch_sdpa", - ] - ) - - # dummy data - baseline_llama_405b_perplexity = { - "perplexities": [ - 9.875290870666504, - 8.075149536132812, - 16.164775848388672, - 11.06580924987793, - 11.46964168548584, - 12.714613914489746, - ], - "mean_perplexity": 11.560880184173584, - } - - delta = 5e-1 - - self.assertAlmostEqual( - baseline_llama_405b_perplexity["mean_perplexity"], - llama_405b_perplexity["mean_perplexity"], - delta=delta, - msg=f"Perplexity is deviating more than {delta}", - ) - - @pytest.mark.xfail - def test_llama3_405B_fp8_decomposed(self): - - # Llama 3.1 405B decomposed - - llama_405b_fp8_gguf_path = "/data/extra/models/llama3.1_405B/llama405b_fp8.gguf" - llama_405b_fp8_tokenizer_path = ( - "/data/extra/models/llama3.1_405B/tokenizer_config.json" - ) - - tensor_parallelism_size = 8 - - llama_405b_perplexity = perplexity.main( - [ - f"--gguf-file={llama_405b_fp8_gguf_path}", - f"--tokenizer-config-json={llama_405b_fp8_tokenizer_path}", - f"--tensor-parallelism-size={tensor_parallelism_size}", - ] - ) - - # dummy data - baseline_llama_405b_perplexity = { - "perplexities": [ - 9.875290870666504, - 8.075149536132812, - 16.164775848388672, - 11.06580924987793, - 11.46964168548584, - 12.714613914489746, - ], - "mean_perplexity": 11.560880184173584, - } - - delta = 5e-1 - - self.assertAlmostEqual( - baseline_llama_405b_perplexity["mean_perplexity"], - llama_405b_perplexity["mean_perplexity"], - delta=delta, - msg=f"Perplexity is deviating more than {delta}", - ) - - @pytest.mark.xfail - def test_llama3_405B_fp8_non_decomposed(self): - - # Llama 3.1 405B non-decomposed - - llama_405b_fp8_gguf_path = "/data/extra/models/llama3.1_405B/llama405b_fp8.gguf" - llama_405b_fp8_tokenizer_path = ( - 
"/data/extra/models/llama3.1_405B/tokenizer_config.json" - ) - - tensor_parallelism_size = 8 - - llama_405b_perplexity = perplexity.main( - [ - f"--gguf-file={llama_405b_fp8_gguf_path}", - f"--tokenizer-config-json={llama_405b_fp8_tokenizer_path}", - f"--tensor-parallelism-size={tensor_parallelism_size}", - f"--attention-kernel=torch_sdpa", - ] - ) - - # dummy data - baseline_llama_405b_perplexity = { - "perplexities": [ - 9.875290870666504, - 8.075149536132812, - 16.164775848388672, - 11.06580924987793, - 11.46964168548584, - 12.714613914489746, - ], - "mean_perplexity": 11.560880184173584, - } - - delta = 5e-1 - - self.assertAlmostEqual( - baseline_llama_405b_perplexity["mean_perplexity"], - llama_405b_perplexity["mean_perplexity"], - delta=delta, - msg=f"Perplexity is deviating more than {delta}", - ) - - -if __name__ == "__main__": - unittest.main() From e700bfa78f29888cbd881986708e26898f7a7ec7 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Wed, 16 Oct 2024 15:29:19 -0400 Subject: [PATCH 2/3] [tuner] Update gpu pipeline option handling (#282) Use the attribute format introduced in https://github.com/iree-org/iree/pull/18458. Fixes: https://github.com/nod-ai/SHARK-Platform/issues/186 --- tuner/tuner/candidate_gen.py | 60 +++++++++++++++++--- tuner/tuner/candidate_gen_test.py | 92 ++++++++++++++++++++++++++----- 2 files changed, 131 insertions(+), 21 deletions(-) diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index 16f0cf724..40eb27a82 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -24,10 +24,10 @@ import pickle import re import z3 -from dataclasses import asdict, dataclass +from dataclasses import astuple, dataclass from enum import Enum from os import mkdir, path, makedirs -from typing import Callable, Optional +from typing import Optional from textwrap import indent from abc import ABC, abstractmethod @@ -148,6 +148,44 @@ def all(): ] +class ReorderWorkgroupsStrategy(Enum): + NONE = 0 + SWIZZLE = 1 + TRANSPOSE = 2 + + def __str__(self) -> str: + return self.name.title() + + +@dataclass +class GpuPipelineOptions: + """Represents the `iree_gpu.pipeline_options` attribute""" + + prefetch_shared_memory: Optional[bool] = None + no_reduce_shared_memory_bank_conflicts: Optional[bool] = None + reorder_workgroups_strategy: Optional[ReorderWorkgroupsStrategy] = None + + def all_default(self) -> bool: + return all(x is None for x in astuple(self)) + + def __str__(self) -> str: + options: list[str] = [] + if self.prefetch_shared_memory is not None: + options.append( + f"prefetch_shared_memory = {str(self.prefetch_shared_memory).lower()}" + ) + if self.no_reduce_shared_memory_bank_conflicts is not None: + options.append( + f"no_reduce_shared_memory_bank_conflicts = {str(self.no_reduce_shared_memory_bank_conflicts).lower()}" + ) + if self.reorder_workgroups_strategy is not None: + options.append( + f"reorder_workgroups_strategy = {self.reorder_workgroups_strategy}" + ) + + return f"#iree_gpu.pipeline_options<{', '.join(options)}>" + + @dataclass class Configuration: subgroup_size: int @@ -156,6 +194,7 @@ class Configuration: tile_sizes: list[int] subgroup_m_count: int subgroup_n_count: int + gpu_pipeline_options: GpuPipelineOptions waves_per_eu: int @@ -223,7 +262,9 @@ def get_batch_mmt_tile_sizes(configuration: Configuration) -> list[int]: def get_pipeline_config(configuration: Configuration) -> str: - extra_config = ", prefetch_shared_memory" + extra_config = "" + if not configuration.gpu_pipeline_options.all_default(): + 
extra_config += f", gpu_pipeline_options = {configuration.gpu_pipeline_options}" if configuration.waves_per_eu != 2: extra_config += f', llvm_func_attrs = {{"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"}}' return extra_config @@ -234,17 +275,19 @@ def apply_configuration( ) -> str: tune_logger.info(f"Applying: {configuration}") expr0 = re.compile( - r", subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>" + r", subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>" ) expr1 = re.compile( r"LLVMGPUVectorDistribute workgroup_size = \[.+\] subgroup_size = ([0-9]+)," ) expr2 = re.compile(r"tile_sizes = \[\[([0-9]+)(, ([0-9]+))+\]\]") - expr3 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"") + expr3 = re.compile(r"gpu_pipeline_options = #iree_gpu\.pipeline_options<([^>]*)>") + expr4 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"") repl0 = f", subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>" repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.workgroup_size))}] subgroup_size = {configuration.subgroup_size},' repl2 = f'tile_sizes = [[{", ".join(map(str, tile_sizes))}]]' - repl3 = f'"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"' + repl3 = f"gpu_pipeline_options = {configuration.gpu_pipeline_options}" + repl4 = f'"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"' new_mlir = "" for line in template: @@ -254,8 +297,10 @@ def apply_configuration( line = re.sub(expr1, repl1, line) if "tile_sizes" in line: line = re.sub(expr2, repl2, line) - if "amdgpu-waves-per-eu" in line: + if "gpu_pipeline_options =" in line: line = re.sub(expr3, repl3, line) + if "amdgpu-waves-per-eu" in line: + line = re.sub(expr4, repl4, line) new_mlir += line return new_mlir @@ -461,6 +506,7 @@ def generate_solutions(problem_size: ProblemSize, num_subgrups: int): [lookup(m), lookup(n), lookup(k)], lookup(sg_m_cnt), lookup(sg_n_cnt), + GpuPipelineOptions(), lookup(waves_per_eu), ) solver.add(z3.simplify(z3.Not(z3.And(list(x == model[x] for x in all_vars))))) diff --git a/tuner/tuner/candidate_gen_test.py b/tuner/tuner/candidate_gen_test.py index 392f8bc06..2924db75b 100644 --- a/tuner/tuner/candidate_gen_test.py +++ b/tuner/tuner/candidate_gen_test.py @@ -67,6 +67,7 @@ def test_get_mmt_tile_sizes(): tile_sizes=[128, 320, 32], subgroup_m_count=0, subgroup_n_count=0, + gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), waves_per_eu=0, ) assert candidate_gen.get_mmt_tile_sizes(config) == [128, 320, 32] @@ -80,6 +81,7 @@ def test_get_conv_tile_sizes(): tile_sizes=[464, 320, 16], subgroup_m_count=1, subgroup_n_count=4, + gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), waves_per_eu=1, ) assert candidate_gen.ConvTuner().get_conv_tile_sizes(config) == [ @@ -93,6 +95,32 @@ def test_get_conv_tile_sizes(): ] +def test_gpu_pipeline_options(): + options = candidate_gen.GpuPipelineOptions() + assert options.all_default() + assert str(options) == "#iree_gpu.pipeline_options<>" + + options.prefetch_shared_memory = True + assert not options.all_default() + assert str(options) == "#iree_gpu.pipeline_options" + + options.no_reduce_shared_memory_bank_conflicts = False + assert ( + str(options) + == "#iree_gpu.pipeline_options" + ) + + options = candidate_gen.GpuPipelineOptions() + options.reorder_workgroups_strategy = ( + candidate_gen.ReorderWorkgroupsStrategy.TRANSPOSE + ) + assert not options.all_default() + assert ( + str(options) + == "#iree_gpu.pipeline_options" + ) + + def 
test_get_contract_tile_sizes(): config = candidate_gen.Configuration( subgroup_size=32, @@ -101,6 +129,7 @@ def test_get_contract_tile_sizes(): tile_sizes=[4, 8, 16], subgroup_m_count=1, subgroup_n_count=1, + gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), waves_per_eu=2, ) assert candidate_gen.get_contract_tile_sizes(config, ["m", "n", "k"]) == [4, 8, 16] @@ -114,28 +143,28 @@ def test_get_contract_tile_sizes(): def test_get_pipeline_config(): - config1 = candidate_gen.Configuration( + config = candidate_gen.Configuration( subgroup_size=32, workgroup_size=[16, 16, 1], intrinsic="", tile_sizes=[4, 8, 16], subgroup_m_count=1, subgroup_n_count=1, + gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), waves_per_eu=2, ) - config2 = candidate_gen.Configuration( - subgroup_size=32, - workgroup_size=[16, 16, 1], - intrinsic="", - tile_sizes=[4, 8, 16], - subgroup_m_count=1, - subgroup_n_count=1, - waves_per_eu=4, - ) - assert candidate_gen.get_pipeline_config(config1) == ", prefetch_shared_memory" + config1_str: str = candidate_gen.get_pipeline_config(config) + assert config1_str == "" + + config.waves_per_eu = 4 + config2_str: str = candidate_gen.get_pipeline_config(config) + assert config2_str == ', llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' + + config.gpu_pipeline_options.prefetch_shared_memory = True + config3_str = candidate_gen.get_pipeline_config(config) assert ( - candidate_gen.get_pipeline_config(config2) - == ', prefetch_shared_memory, llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' + config3_str + == ', gpu_pipeline_options = #iree_gpu.pipeline_options, llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' ) @@ -409,11 +438,18 @@ def test_generate_constraints_invalid_input(): assert solver.check() == candidate_gen.z3.unsat +def remove_comments(mlir: str) -> str: + return "\n".join( + filter(lambda x: not x.lstrip().startswith("//"), mlir.splitlines()) + ) + + def test_apply_params_mmt(): mlir_template = [ ", subgroup_m_count = 16, subgroup_n_count = 16>", "", + "gpu_pipeline_options = #iree_gpu.pipeline_options", '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}', ] @@ -426,6 +462,9 @@ def test_apply_params_mmt(): tile_sizes=[8, 8, 8], subgroup_m_count=16, subgroup_n_count=16, + gpu_pipeline_options=candidate_gen.GpuPipelineOptions( + prefetch_shared_memory=True + ), waves_per_eu=8, ) @@ -442,6 +481,7 @@ def test_apply_params_mmt(): embeddable = tf_mlir.embeddable assert modified + modified = remove_comments(modified) assert embeddable assert ( "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 16, subgroup_n_count = 16" @@ -452,6 +492,10 @@ def test_apply_params_mmt(): in modified ) assert "tile_sizes = [[8, 8, 8]]" in modified + assert ( + "gpu_pipeline_options = #iree_gpu.pipeline_options" + in modified + ) assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "8"}' in modified @@ -460,7 +504,7 @@ def test_apply_params_conv(): ", subgroup_m_count = 16, subgroup_n_count = 16>", "", - '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}', + 'gpu_pipeline_options = #iree_gpu.pipeline_options, {llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}', ] n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 640 @@ -472,6 +516,9 @@ def test_apply_params_conv(): tile_sizes=[464, 320, 16], subgroup_m_count=1, subgroup_n_count=4, + gpu_pipeline_options=candidate_gen.GpuPipelineOptions( + reorder_workgroups_strategy=candidate_gen.ReorderWorkgroupsStrategy.TRANSPOSE + ), waves_per_eu=2, ) @@ -492,6 +539,8 @@ def test_apply_params_conv(): embeddable = tf_mlir.embeddable assert modified + 
modified = remove_comments(modified) + assert embeddable assert ( "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 1, subgroup_n_count = 4" @@ -502,6 +551,10 @@ def test_apply_params_conv(): in modified ) assert "tile_sizes = [[1, 1, 464, 320, 1, 1, 16]]" in modified + assert ( + "gpu_pipeline_options = #iree_gpu.pipeline_options" + in modified + ) assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified @@ -529,6 +582,7 @@ def test_apply_params_contract(): tile_sizes=[480, 384, 32], subgroup_m_count=1, subgroup_n_count=4, + gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), waves_per_eu=2, ) @@ -575,6 +629,7 @@ def test_apply_params_batch_matmul(): tile_sizes=[416, 320, 128], subgroup_m_count=2, subgroup_n_count=2, + gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), waves_per_eu=2, ) @@ -586,6 +641,8 @@ def test_apply_params_batch_matmul(): embeddable = tf_mlir.embeddable assert modified + modified = remove_comments(modified) + assert embeddable assert ( "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 2, subgroup_n_count = 2" @@ -622,6 +679,7 @@ def test_apply_params_batch_mmt_float(): tile_sizes=[128, 64, 128], subgroup_m_count=2, subgroup_n_count=2, + gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), waves_per_eu=2, ) @@ -669,6 +727,7 @@ def test_apply_params_batch_mmt_int(): tile_sizes=[128, 64, 128], subgroup_m_count=2, subgroup_n_count=2, + gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), waves_per_eu=4, ) @@ -681,6 +740,8 @@ def test_apply_params_batch_mmt_int(): assert modified assert "// transform.named_sequence @match_batch_mmt_2x4096x640x640(" in modified + modified = remove_comments(modified) + assert ( "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 2, subgroup_n_count = 2" in modified @@ -737,6 +798,7 @@ def test_apply_params_broadcast_rhs_mmt(): tile_sizes=[128, 64, 128], subgroup_m_count=2, subgroup_n_count=2, + gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), waves_per_eu=4, ) @@ -752,6 +814,8 @@ def test_apply_params_broadcast_rhs_mmt(): "// transform.named_sequence @match_broadcast_rhs_mmt_Bx4096x640x640(" in modified ) + modified = remove_comments(modified) + assert ( "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 2, subgroup_n_count = 2" in modified From f5fcd007350e23cadc0ad793e3f91d6a7eed28ea Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Wed, 16 Oct 2024 12:56:13 -0700 Subject: [PATCH 3/3] Refactor llama / mixtral / grok for shared features (#267) Many of these features can toggle between depending on architecture. Replumbing the configurations separately allows better reuse and understanding of how models vary between eachother. grok uses a softcap, plumbing a value enables `sc * tanh( v / sc)` grok has some hardcoded values that have better representations, e.g. `sqrt(6144)` and `sqrt(3)`. output normalization is optional but used by mixtral. Presence of the tensor is sufficient for performing the normalization. 
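To make the softcap wording above concrete: the value is plumbed into `PagedLlamaAttentionBlock` as an optional `softcap` argument alongside an optional `attention_scale`, and in isolation the score transformation reduces to the sketch below. The helper name is illustrative and not part of this diff; it only restates the `sc * tanh(v / sc)` formula and the default `1/sqrt(head_dim)` scaling.

```python
import math
from typing import Optional

import torch


def scale_and_cap_attention_scores(
    attn_weights: torch.Tensor,
    head_dim: int,
    attention_scale: Optional[float] = None,
    softcap: Optional[float] = None,
) -> torch.Tensor:
    # Default scaling is 1/sqrt(head_dim); a model may override it with an explicit scale.
    if attention_scale is None:
        attn_weights = attn_weights / math.sqrt(head_dim)
    else:
        attn_weights = attn_weights * attention_scale

    # Soft cap: sc * tanh(v / sc) bounds the pre-softmax scores to (-sc, sc).
    # This patch passes softcap=30.0 for the Grok model.
    if softcap is not None:
        attn_weights = softcap * torch.tanh(attn_weights / softcap)
    return attn_weights
```

For Grok this replaces the previous hardcoded `30.0 * torch.tanh(attn_weights * (0.08838834764831845 / 30.0))`, where `0.08838834764831845` is `1/sqrt(128)`, i.e. the `1/sqrt(head_dim)` scale folded into the cap.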
--- .../sharktank/export_layer/export_moe.py | 13 +- sharktank/sharktank/layers/__init__.py | 2 +- sharktank/sharktank/layers/ffn_moe_block.py | 12 +- .../sharktank/layers/llama_attention_block.py | 6 +- .../layers/mixture_of_experts_block.py | 114 ++---------------- .../layers/paged_llama_attention_block.py | 36 +++--- sharktank/sharktank/models/grok/grok.py | 87 ++++++------- sharktank/sharktank/models/llama/llama.py | 1 - sharktank/sharktank/models/llama/llama_ref.py | 5 +- sharktank/sharktank/models/mixtral/mixtral.py | 77 ++++++------ .../sharktank/models/mixtral/mixtral_ref.py | 42 ++++--- sharktank/sharktank/utils/tokenizer.py | 1 + .../tests/models/llama/moe_block_test.py | 7 +- 13 files changed, 159 insertions(+), 244 deletions(-) diff --git a/sharktank/sharktank/export_layer/export_moe.py b/sharktank/sharktank/export_layer/export_moe.py index f2c10c4b4..af4ed51d2 100644 --- a/sharktank/sharktank/export_layer/export_moe.py +++ b/sharktank/sharktank/export_layer/export_moe.py @@ -5,9 +5,12 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import torch +import torch.nn.functional as F + from iree.turbine.aot import * + from sharktank.models.llama.testing import make_moe_block_theta, make_rand_torch -from sharktank.layers.mixture_of_experts_block import PreGatherMoeBlock +from sharktank.layers.mixture_of_experts_block import MoeBlock from ..utils import cli @@ -37,8 +40,8 @@ def main(): action="store_true", ) parser.add_argument( - "--use-grok", - help="Enable to export Grok model's version of MOE block", + "--use-gelu", + help="Enable to use gelu for moe activation", action="store_true", ) @@ -46,12 +49,12 @@ def main(): bs = args.batch_size - model = PreGatherMoeBlock( + model = MoeBlock( theta=make_moe_block_theta()("blk.0"), expert_count=8, expert_used_count=2, rms_epsilon=1e-5, - use_grok=args.use_grok, + moe_activation=F.gelu if args.use_gelu else F.silu, ) fxb = FxProgramsBuilder(model) input = make_rand_torch((bs, 32, 6144)) diff --git a/sharktank/sharktank/layers/__init__.py b/sharktank/sharktank/layers/__init__.py index a90def3a9..fd56ec872 100644 --- a/sharktank/sharktank/layers/__init__.py +++ b/sharktank/sharktank/layers/__init__.py @@ -16,6 +16,6 @@ from .paged_llama_attention_block import PagedLlamaAttentionBlock from .ffn_block import FFN from .ffn_moe_block import FFNMOE -from .mixture_of_experts_block import SparseMoeBlock, PreGatherMoeBlock +from .mixture_of_experts_block import MoeBlock from .configs import * diff --git a/sharktank/sharktank/layers/ffn_moe_block.py b/sharktank/sharktank/layers/ffn_moe_block.py index 0536302cf..0746f0fa0 100644 --- a/sharktank/sharktank/layers/ffn_moe_block.py +++ b/sharktank/sharktank/layers/ffn_moe_block.py @@ -12,7 +12,7 @@ from .base import ThetaLayer from .linear import LinearLayer from ..types import Theta, DefaultPrimitiveTensor -from ..ops import einsum_2args +from ..ops import einsum_2args, elementwise __all__ = [ "FFNMOE", @@ -24,15 +24,15 @@ class PreGatherFFNMOE(ThetaLayer): def __init__( self, theta: Theta, - use_grok: bool = False, + activation=F.silu, ): super().__init__(theta) - self.use_grok = use_grok self.ffn_gate = theta.tensor("ffn_gate_exps", "weight") self.ffn_up = theta.tensor("ffn_up_exps", "weight") self.ffn_down = theta.tensor("ffn_down_exps", "weight") + self.activation = activation def pre_matmul_gather(self, inputs, weights, experts, einstring="mk,menk->men"): inputs = inputs[:, :] @@ -63,10 +63,8 @@ def forward( experts: torch.Tensor, expert_gate: torch.Tensor, ): - if self.use_grok: - 
ffn_gate = F.gelu(self.pre_matmul_gather(h, self.ffn_gate, experts)) - else: - ffn_gate = F.silu(self.pre_matmul_gather(h, self.ffn_gate, experts)) + ffn_gate = self.pre_matmul_gather(h, self.ffn_gate, experts) + ffn_gate = elementwise(self.activation, ffn_gate) ffn_up = self.pre_matmul_gather(h, self.ffn_up, experts) ffn_down = self.pre_matmul_gather( diff --git a/sharktank/sharktank/layers/llama_attention_block.py b/sharktank/sharktank/layers/llama_attention_block.py index 7be8c7366..0cdb5d713 100644 --- a/sharktank/sharktank/layers/llama_attention_block.py +++ b/sharktank/sharktank/layers/llama_attention_block.py @@ -6,8 +6,6 @@ from typing import Optional -import math - import torch import torch.nn.functional as F @@ -110,7 +108,9 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor: values = values.transpose(1, 2) # Flash attention. - attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) + attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / torch.sqrt( + self.head_dim + ) # Apply attention mask. if attention_mask is not None: diff --git a/sharktank/sharktank/layers/mixture_of_experts_block.py b/sharktank/sharktank/layers/mixture_of_experts_block.py index f788d06f0..ddce16c55 100644 --- a/sharktank/sharktank/layers/mixture_of_experts_block.py +++ b/sharktank/sharktank/layers/mixture_of_experts_block.py @@ -16,12 +16,11 @@ from .ffn_moe_block import FFNMOE, PreGatherFFNMOE __all__ = [ - "SparseMoeBlock", - "PreGatherMoeBlock", + "MoeBlock", ] -class SparseMoeBlock(ThetaLayer): +class MoeBlock(ThetaLayer): """ This implementation considers MoE operations as block-sparse operations to support imbalanced token assignments to experts. @@ -35,108 +34,12 @@ def __init__( expert_count: int, expert_used_count: int, rms_epsilon: float, - ): - super().__init__(theta) - - # Add router gate - self.add_module("ffn_gate_inp", LinearLayer(theta("ffn_gate_inp"))) - - # Add FFN norm - self.add_module( - "ffn_norm", RMSNormLayer(theta("ffn_norm"), epsilon=rms_epsilon) - ) - - # Add FFN output norm - self.add_module( - "layer_output_norm", - RMSNormLayer(theta("layer_output_norm"), epsilon=rms_epsilon), - ) - - # Add expert_count x FFN - self.experts = nn.ModuleList( - [FFNMOE(theta, expert_idx=i) for i in range(expert_count)] - ) - - self.expert_count = expert_count - self.expert_used_count = expert_used_count - - def forward( - self, - h: torch.Tensor, - ): - ffn_input = self.ffn_norm(h) - batch_size, sequence_length, feature_dim = ffn_input.shape - ffn_input = ffn_input.view(-1, feature_dim) - - # For each token, the router calculates the router weights for all experts - # router_logits: (batch_size * sequence_length, expert_count) - router_logits = self.ffn_gate_inp(ffn_input) - router_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - - # Select top k experts from router weights - router_weights, top_k_experts = torch.topk( - router_weights, self.expert_used_count, dim=-1 - ) - router_weights /= router_weights.sum(dim=-1, keepdim=True) - router_weights = router_weights.to(ffn_input.dtype) - - moe_output = torch.zeros( - (batch_size * sequence_length, feature_dim), dtype=ffn_input.dtype - ) - - # Create an expert mask by one hot encoding the selected top k experts - # used to index which expert is to be invoked for each token - # expert_mask: (expert_count, expert_used_count, sequence_length) - expert_mask = F.one_hot(top_k_experts, num_classes=self.expert_count).permute( - 2, 1, 0 - ) - - # Iterate over all experts in the model - for expert_idx in 
range(self.expert_count): - expert_layer = self.experts[expert_idx] - top_k_expert_idx, token_idx = torch.where(expert_mask[expert_idx]) - - # Given the hidden states, index the tokens assigned to this expert - # and calculate the current expert's hidden state and weigh the - # output expert hidden states by the router weights, based on the - # appropriate tokens - current_expert_tokens = ffn_input[None, token_idx] - - current_expert = ( - expert_layer(current_expert_tokens) - * router_weights[token_idx, top_k_expert_idx, None] - ) - - current_expert = current_expert.reshape(-1, feature_dim) - - moe_output.index_add_(0, token_idx, current_expert.to(ffn_input.dtype)) - moe_output = moe_output.reshape(batch_size, sequence_length, feature_dim) - - moe_output = self.layer_output_norm(moe_output) - return h + moe_output - - -class PreGatherMoeBlock(ThetaLayer): - """ - This implementation considers MoE operations as block-sparse - operations to support imbalanced token assignments to experts. - This enables the MoE to operate at a faster rate and in full capacity without any dropped tokens - (or reduced performance). - """ - - def __init__( - self, - theta: Theta, - expert_count: int, - expert_used_count: int, - rms_epsilon: float, - use_grok: Optional[bool] = False, + moe_activation=F.silu, ): super().__init__(theta) self.expert_count = expert_count self.expert_used_count = expert_used_count - self.use_grok = use_grok # Add router gate self.add_module("ffn_gate_inp", LinearLayer(theta("ffn_gate_inp"))) @@ -146,15 +49,17 @@ def __init__( "ffn_norm", RMSNormLayer(theta("ffn_norm"), epsilon=rms_epsilon) ) - # Add FFN output norm layer for Grok - if self.use_grok: + # Add optional FFN output norm layer + if theta.optional_tensor("layer_output_norm") is not None: self.add_module( "layer_output_norm", RMSNormLayer(theta("layer_output_norm"), epsilon=rms_epsilon), ) + else: + self.add_module("layer_output_norm", torch.nn.Identity()) # Add expert_count x FFN - self.experts = PreGatherFFNMOE(theta, use_grok=self.use_grok) + self.experts = PreGatherFFNMOE(theta, activation=moe_activation) def forward( self, @@ -180,7 +85,6 @@ def forward( moe_output = self.experts(ffn_input, top_k_experts, expert_gate) moe_output = moe_output.reshape(batch_size, sequence_length, feature_dim) - if self.use_grok: - moe_output = self.layer_output_norm(moe_output) + moe_output = self.layer_output_norm(moe_output) return h + moe_output diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py index 123c43be7..5bc045608 100644 --- a/sharktank/sharktank/layers/paged_llama_attention_block.py +++ b/sharktank/sharktank/layers/paged_llama_attention_block.py @@ -37,7 +37,8 @@ def __init__( head_dim: int, head_count_kv: int, rms_epsilon: float, - use_grok: Optional[bool] = False, + attention_scale: Optional[float] = None, + softcap: Optional[float] = None, ): super().__init__(theta) @@ -46,7 +47,8 @@ def __init__( self.head_count = head_count self.head_dim = head_dim self.head_count_kv = head_count_kv - self.use_grok = use_grok + self.attention_scale = attention_scale + self.softcap = softcap self.add_module( "attn_norm", RMSNormLayer(theta("attn_norm"), epsilon=rms_epsilon) @@ -56,7 +58,12 @@ def __init__( self.add_module("attn_v", LinearLayer(theta("attn_v"))) self.add_module("attn_output", LinearLayer(theta("attn_output"))) - if self.use_grok: + if theta.optional_tensor("attn_output_norm") is None: + self.add_module( + "attn_output_norm", + 
torch.nn.Identity(), + ) + else: self.add_module( "attn_output_norm", RMSNormLayer(theta("attn_output_norm"), epsilon=rms_epsilon), @@ -147,16 +154,16 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor: keys = xk.transpose(1, 2) values = xv.transpose(1, 2) + attn_weights = ops.matmul(xq, keys.transpose(2, 3)) + if self.attention_scale is None: + attn_weights = attn_weights / math.sqrt(self.head_dim) + else: + attn_weights = attn_weights * self.attention_scale + # Flash attention. - if not self.use_grok: - attn_weights = ops.matmul(xq, keys.transpose(2, 3)) / math.sqrt( - self.head_dim - ) - elif self.use_grok: - attn_weights = ops.matmul(xq, keys.transpose(2, 3)) - attn_weights = 30.0 * torch.tanh( - attn_weights * (0.08838834764831845 / 30.0) - ) + if self.softcap is not None: + attn_weights = self.softcap * torch.tanh(attn_weights / self.softcap) + self.assert_not_nan(attn_weights) # Apply attention mask. @@ -172,12 +179,9 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor: # Project. attn_output = self.attn_output(attn_output) - - if self.use_grok: - attn_output = self.attn_output_norm(attn_output) + attn_output = self.attn_output_norm(attn_output) h = h + attn_output - return h def transact_cache_direct( diff --git a/sharktank/sharktank/models/grok/grok.py b/sharktank/sharktank/models/grok/grok.py index debeb30c7..077e4e064 100644 --- a/sharktank/sharktank/models/grok/grok.py +++ b/sharktank/sharktank/models/grok/grok.py @@ -4,9 +4,11 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import math import torch import torch.nn as nn +import torch.nn.functional as F from ...layers import * @@ -82,6 +84,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): self.add_module("output_lm_head", LinearLayer(theta("output"))) self.attn_blocks = nn.ModuleList() + self.moe_blocks = nn.ModuleList() for n in range(hp.block_count): self.attn_blocks.append( @@ -93,16 +96,16 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): head_dim=hp.attn_head_dim, head_count_kv=hp.attention_head_count_kv, rms_epsilon=hp.attention_layer_norm_rms_epsilon, - use_grok=True, + softcap=30.0, # https://github.com/xai-org/grok-1/blob/7050ed204b8206bb8645c7b7bbef7252f79561b0/model.py#L864 ) ) - self.attn_blocks.append( - PreGatherMoeBlock( + self.moe_blocks.append( + MoeBlock( theta("blk", n), expert_count=hp.expert_count, expert_used_count=hp.expert_used_count, rms_epsilon=hp.attention_layer_norm_rms_epsilon, - use_grok=True, + moe_activation=F.gelu, ) ) @@ -122,33 +125,32 @@ def prefill( self._assert_device(seq_block_ids) self._assert_device(*cache_state, dtype=self.activation_dtype) h = self.token_embedding(tokens) - h *= 78.38367176906169 + h *= math.sqrt(h.shape[-1]) self.trace_tensor("grok.token_embedding", h) # Iterate over attention blocks. 
- for block_idx, block in enumerate(self.attn_blocks): + for block_idx, (attn_block, moe_block) in enumerate( + zip(self.attn_blocks, self.moe_blocks) + ): if block_idx == 0: self.trace_tensor(f"grok.attn_block.{block_idx}.input", h) - if block.__class__.__name__ == "PagedLlamaAttentionBlock": - h = block( - h, - embedding=self.attention_embedding, - start_index=0, - attention_mask=attention_mask, - cache_state=cache_state, - seq_block_ids=seq_block_ids, - ) - self.trace_tensor(f"grok.attn_block.{block_idx}.output", h) - elif block.__class__.__name__ == "PreGatherMoeBlock": - h = block( - h, - ) - self.trace_tensor(f"grok.moe_block.{block_idx}.output", h) + h = attn_block( + h, + embedding=self.attention_embedding, + start_index=0, + attention_mask=attention_mask, + cache_state=cache_state, + seq_block_ids=seq_block_ids, + ) + self.trace_tensor(f"grok.attn_block.{block_idx}.output", h) + + h = moe_block(h) + self.trace_tensor(f"grok.moe_block.{block_idx}.output", h) h = self.output_norm(h) logits = self.output_lm_head(h) - logits = logits * 0.5773502691896257 + logits = logits / math.sqrt(3.0) return logits def decode( @@ -200,34 +202,33 @@ def decode( ) h = self.token_embedding(tokens) - h *= 78.38367176906169 + h *= math.sqrt(h.shape[-1]) self.trace_tensor("grok.token_embedding", h) # Iterate over attention blocks. - for block_idx, block in enumerate(self.attn_blocks): + for block_idx, (attn_block, moe_block) in enumerate( + zip(self.attn_blocks, self.moe_blocks) + ): if block_idx == 0: self.trace_tensor(f"grok.attn_block.{block_idx}.input", h) - if block.__class__.__name__ == "PagedLlamaAttentionBlock": - h = block( - h, - start_positions=start_positions, - embedding=self.attention_embedding, - embedding_batch_mask=embedding_batch_mask, - attention_mask=attention_mask, - cache_state=cache_state, - seq_block_ids=seq_block_ids, - xk_temp=xk_temp, - xv_temp=xv_temp, - ) - self.trace_tensor(f"grok.attn_block.{block_idx}.output", h) - elif block.__class__.__name__ == "PreGatherMoeBlock": - h = block( - h, - ) - self.trace_tensor(f"grok.moe_block.{block_idx}.output", h) + h = attn_block( + h, + start_positions=start_positions, + embedding=self.attention_embedding, + embedding_batch_mask=embedding_batch_mask, + attention_mask=attention_mask, + cache_state=cache_state, + seq_block_ids=seq_block_ids, + xk_temp=xk_temp, + xv_temp=xv_temp, + ) + self.trace_tensor(f"grok.attn_block.{block_idx}.output", h) + + h = moe_block(h) + self.trace_tensor(f"grok.moe_block.{block_idx}.output", h) h = self.output_norm(h) logits = self.output_lm_head(h) - logits = logits * 0.5773502691896257 + logits = logits / math.sqrt(3.0) return logits diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py index c324a79d5..344976ead 100644 --- a/sharktank/sharktank/models/llama/llama.py +++ b/sharktank/sharktank/models/llama/llama.py @@ -7,7 +7,6 @@ from typing import Optional from dataclasses import dataclass -import math from typing import Union import torch diff --git a/sharktank/sharktank/models/llama/llama_ref.py b/sharktank/sharktank/models/llama/llama_ref.py index 74ed9e8e0..9f77daa40 100644 --- a/sharktank/sharktank/models/llama/llama_ref.py +++ b/sharktank/sharktank/models/llama/llama_ref.py @@ -7,7 +7,6 @@ from typing import Optional from dataclasses import dataclass -import math import torch import torch.nn as nn @@ -230,7 +229,9 @@ def forward( values = values.transpose(1, 2) # Flash attention. 
-        attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
+        attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / torch.sqrt(
+            torch.tensor(self.head_dim, dtype=xq.dtype)
+        )
 
         # Apply attention mask.
         if attention_mask is not None:
diff --git a/sharktank/sharktank/models/mixtral/mixtral.py b/sharktank/sharktank/models/mixtral/mixtral.py
index 1fc86f87d..e2995dfde 100644
--- a/sharktank/sharktank/models/mixtral/mixtral.py
+++ b/sharktank/sharktank/models/mixtral/mixtral.py
@@ -85,6 +85,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
         self.add_module("output_lm_head", LinearLayer(theta("output")))
 
         self.attn_blocks = nn.ModuleList()
+        self.moe_blocks = nn.ModuleList()
 
         for n in range(hp.block_count):
             self.attn_blocks.append(
@@ -98,8 +99,8 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
                     rms_epsilon=hp.attention_layer_norm_rms_epsilon,
                 )
             )
-            self.attn_blocks.append(
-                SparseMoeBlock(
+            self.moe_blocks.append(
+                MoeBlock(
                     theta("blk", n),
                     expert_count=hp.expert_count,
                     expert_used_count=hp.expert_used_count,
@@ -126,25 +127,26 @@ def prefill(
         self.trace_tensor("mixtral.token_embedding", h)
 
         # Iterate over attention blocks.
-        for block_idx, block in enumerate(self.attn_blocks):
+        for block_idx, (attn_block, moe_block) in enumerate(
+            zip(self.attn_blocks, self.moe_blocks)
+        ):
             if block_idx == 0:
                 self.trace_tensor(f"mixtral.attn_block.{block_idx}.input", h)
-            if block.__class__.__name__ == "PagedLlamaAttentionBlock":
-                h = block(
-                    h,
-                    embedding=self.attention_embedding,
-                    start_index=0,
-                    attention_mask=attention_mask,
-                    cache_state=cache_state,
-                    seq_block_ids=seq_block_ids,
-                )
-                self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h)
-            elif block.__class__.__name__ == "SparseMoeBlock":
-                h = block(
-                    h,
-                )
-                self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h)
+            h = attn_block(
+                h,
+                embedding=self.attention_embedding,
+                start_index=0,
+                attention_mask=attention_mask,
+                cache_state=cache_state,
+                seq_block_ids=seq_block_ids,
+            )
+            self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h)
+
+            h = moe_block(
+                h,
+            )
+            self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h)
 
         h = self.output_norm(h)
         logits = self.output_lm_head(h)
@@ -202,28 +204,29 @@ def decode(
         self.trace_tensor("mixtral.token_embedding", h)
 
         # Iterate over attention blocks.
- for block_idx, block in enumerate(self.attn_blocks): + for block_idx, (attn_block, moe_block) in enumerate( + zip(self.attn_blocks, self.moe_blocks) + ): if block_idx == 0: self.trace_tensor(f"mixtral.attn_block.{block_idx}.input", h) - if block.__class__.__name__ == "PagedLlamaAttentionBlock": - h = block( - h, - start_positions=start_positions, - embedding=self.attention_embedding, - embedding_batch_mask=embedding_batch_mask, - attention_mask=attention_mask, - cache_state=cache_state, - seq_block_ids=seq_block_ids, - xk_temp=xk_temp, - xv_temp=xv_temp, - ) - self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h) - elif block.__class__.__name__ == "SparseMoeBlock": - h = block( - h, - ) - self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h) + h = attn_block( + h, + start_positions=start_positions, + embedding=self.attention_embedding, + embedding_batch_mask=embedding_batch_mask, + attention_mask=attention_mask, + cache_state=cache_state, + seq_block_ids=seq_block_ids, + xk_temp=xk_temp, + xv_temp=xv_temp, + ) + self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h) + + h = moe_block( + h, + ) + self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h) h = self.output_norm(h) logits = self.output_lm_head(h) diff --git a/sharktank/sharktank/models/mixtral/mixtral_ref.py b/sharktank/sharktank/models/mixtral/mixtral_ref.py index 392f60a25..70a9b9cf8 100644 --- a/sharktank/sharktank/models/mixtral/mixtral_ref.py +++ b/sharktank/sharktank/models/mixtral/mixtral_ref.py @@ -66,6 +66,7 @@ def __init__(self, theta: Theta, config: RefLlamaModelConfig): self.add_module("output_lm_head", LinearLayer(theta("output"))) self.attn_blocks = nn.ModuleList() + self.moe_blocks = nn.ModuleList() for n in range(hp.block_count): self.attn_blocks.append( @@ -78,8 +79,8 @@ def __init__(self, theta: Theta, config: RefLlamaModelConfig): rms_epsilon=hp.attention_layer_norm_rms_epsilon, ) ) - self.attn_blocks.append( - SparseMoeBlock( + self.moe_blocks.append( + MoeBlock( theta("blk", n), expert_count=hp.expert_count, expert_used_count=hp.expert_used_count, @@ -130,28 +131,29 @@ def forward( block_count = len(self.attn_blocks) // 2 # print('local_kv_cache, #attn_blocks', len(local_kv_cache), block_count) # Iterate over attention + MoE blocks. 
-        for block_idx, block in enumerate(self.attn_blocks):
+        for block_idx, (attn_block, moe_block) in enumerate(
+            zip(self.attn_blocks, self.moe_blocks)
+        ):
             # print("block_idx, block", block_idx, block)
             if block_idx == 0:
                 self.trace_tensor(f"mixtral.attn_block.{block_idx}.input", h)
-            if block.__class__.__name__ == "LlamaAttentionBlock":
-                attn_block_idx = block_idx // 2
-                block_cache_k = local_kv_cache[attn_block_idx]
-                block_cache_v = local_kv_cache[block_count + attn_block_idx]
-                h = block(
-                    h,
-                    cache_k=block_cache_k,
-                    cache_v=block_cache_v,
-                    start_index=start_index,
-                    attention_mask=attention_mask,
-                )
-                self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h)
-            elif block.__class__.__name__ == "SparseMoeBlock":
-                h = block(
-                    h,
-                )
-                self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h)
+            # One K and one V cache entry per layer; V entries follow all K entries.
+            block_cache_k = local_kv_cache[block_idx]
+            block_cache_v = local_kv_cache[len(local_kv_cache) // 2 + block_idx]
+            h = attn_block(
+                h,
+                cache_k=block_cache_k,
+                cache_v=block_cache_v,
+                start_index=start_index,
+                attention_mask=attention_mask,
+            )
+            self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h)
+
+            h = moe_block(
+                h,
+            )
+            self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h)
 
         h = self.output_norm(h)
         logits = self.output_lm_head(h)
diff --git a/sharktank/sharktank/utils/tokenizer.py b/sharktank/sharktank/utils/tokenizer.py
index 29a57f958..a6b0980a0 100644
--- a/sharktank/sharktank/utils/tokenizer.py
+++ b/sharktank/sharktank/utils/tokenizer.py
@@ -31,6 +31,7 @@ def encode(
         raw_rows = self._encode(texts)
         max_length = 0
         lengths: list[int] = []
+        raw_rows = [row[1:] for row in raw_rows]
         for row in raw_rows:
             lengths.append(len(row))
             max_length = max(max_length, len(row))
diff --git a/sharktank/tests/models/llama/moe_block_test.py b/sharktank/tests/models/llama/moe_block_test.py
index edf1d9d97..9b3daabdf 100644
--- a/sharktank/tests/models/llama/moe_block_test.py
+++ b/sharktank/tests/models/llama/moe_block_test.py
@@ -10,18 +10,17 @@
 import torch
 from iree.turbine.aot import *
 from sharktank.models.llama.testing import make_moe_block_theta, make_rand_torch
-from sharktank.layers.mixture_of_experts_block import PreGatherMoeBlock
+from sharktank.layers.mixture_of_experts_block import MoeBlock
 from sharktank import ops
 
 
-class SparseMoeBlockTest(unittest.TestCase):
+class MoeBlockTest(unittest.TestCase):
     def test(self):
-        model = PreGatherMoeBlock(
+        model = MoeBlock(
             theta=make_moe_block_theta()("blk.0"),
             expert_count=8,
             expert_used_count=2,
             rms_epsilon=1e-5,
-            use_grok=False,
         )
         fxb = FxProgramsBuilder(model)
         input = make_rand_torch((2, 32, 6144))
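
Note for reviewers: the hardcoded Grok constants removed above are `0.08838834764831845 = 1 / sqrt(128)` (Grok's attention head dimension), `78.38367176906169 = sqrt(6144)` (its embedding dimension), and `0.5773502691896257 = 1 / sqrt(3)`, so with `attention_scale=None` and `softcap=30.0` the refactored scale-then-softcap path reproduces the previous attention expression. A minimal standalone sketch of that equivalence check (not part of the patch; the tensor shape and dtype are arbitrary):

```python
import math

import torch

# Sanity check: Grok's old hardcoded expression
#   30.0 * tanh(logits * (0.08838834764831845 / 30.0))
# matches the refactored path
#   30.0 * tanh((logits / sqrt(head_dim)) / 30.0)
# when head_dim == 128.
head_dim = 128
softcap = 30.0
logits = torch.randn(2, 8, 16, 16, dtype=torch.float64)

old = softcap * torch.tanh(logits * (0.08838834764831845 / softcap))
new = softcap * torch.tanh((logits / math.sqrt(head_dim)) / softcap)

torch.testing.assert_close(old, new)
```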