From 15fae20417783ebb20606bc9ebe86f19198d8ea5 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Thu, 31 Oct 2024 06:25:24 -0500 Subject: [PATCH] Add IREE numerics test for Llama 3.1 8B FP16 TP8 Introduce a Llama 3.1 8B FP16 TP8 test that appears to not have good numerical accuracy. It is compared to an fp64 unsharded torch variant to ensure that the reference is of high accuracy. Refactor the sharded Llama tests. Increase code reuse and use the TorchGenerator in the toy-sized tests. Use the shard_llm_dataset and export_paged_llm_v1 scripts in the test flow to increase their test coverage. --- .../sharktank/examples/export_paged_llm_v1.py | 13 +- sharktank/sharktank/layers/kv_cache.py | 5 +- sharktank/sharktank/models/llama/llama.py | 23 - sharktank/sharktank/utils/testing.py | 33 + sharktank/sharktank/utils/tokenizer.py | 2 +- .../tests/models/llama/sharded_llama_test.py | 723 ++++++++++-------- 6 files changed, 435 insertions(+), 364 deletions(-) diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py index 84b174bba..93f60f8f8 100644 --- a/sharktank/sharktank/examples/export_paged_llm_v1.py +++ b/sharktank/sharktank/examples/export_paged_llm_v1.py @@ -14,15 +14,12 @@ from sharktank.layers import * from sharktank.types import * -# TODO: Should be using a base class with the protocol supported. from ..models.llama.llama import LlamaModelConfig, PagedLlamaModelV1 -from ..models.llama.sharding import shard_theta from ..models.mixtral.mixtral import * from ..models.grok.grok import * -from .. import ops -def main(): +def main(raw_args: list[str] | None = None): from ..utils import cli parser = cli.create_parser() @@ -60,7 +57,7 @@ def main(): choices=["decomposed", "torch"], ) - args = cli.parse(parser) + args = cli.parse(parser, args=raw_args) dataset_type = cli.get_input_data_files(args) dataset_type = "irpa" if "irpa" in dataset_type else "gguf" dataset = cli.get_input_dataset(args) @@ -110,7 +107,7 @@ def generate_params_json(hp, prefill_bs: list[int], decode_bs: list[int]): fxb = FxProgramsBuilder(model) - def setup_cache(model, shard_count): + def setup_cache(model): if model.config.kv_cache_type == "paged": cache_state = model.cache.allocate( page_count=hp.context_length // llama_config.block_seq_stride @@ -161,7 +158,7 @@ def generate_batch_prefill(bs: int): sl_dim = llama_config.block_seq_stride * block_dim cache, cache_shard_dim, cache_dynamic_shapes, arg_affinities = setup_cache( - model, llama_config.tensor_parallelism_size + model ) # We need to offset the indices for the cache @@ -234,7 +231,7 @@ def generate_batch_decode(bs: int): cache_shard_dim, cache_dynamic_shapes, arg_affinities, - ) = setup_cache(model, llama_config.tensor_parallelism_size) + ) = setup_cache(model) # We need to offset the indices for the cache arg_affinities = {key + 4: arg_affinities[key] for key in arg_affinities} diff --git a/sharktank/sharktank/layers/kv_cache.py b/sharktank/sharktank/layers/kv_cache.py index d7ade43a7..fa6ba587b 100644 --- a/sharktank/sharktank/layers/kv_cache.py +++ b/sharktank/sharktank/layers/kv_cache.py @@ -300,7 +300,7 @@ def shard_state( """Shard an unsharded state. We can't just split the slab on the sub page dims. First it needs to be reinterpreted into the actual shape. - The split the head dimension, then flatten each shard. + Then split the head dimension, then flatten each shard. This is a work-around for the lack of block-cyclic sharded tensor type.""" if self.shard_count == 1: return state @@ -324,6 +324,9 @@ def shard_state( flat_sharded_page_table = SplitPrimitiveTensor(ts=shards, shard_dim=1) return [flat_sharded_page_table] + def unshard_state(self, state: list[SplitPrimitiveTensor]) -> list[torch.Tensor]: + return [ops.unshard(self.unflatten_page_table(state)).flatten(start_dim=1)] + @property def pad_sequence_stride(self) -> int: return self.block_seq_stride diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py index 656b4432b..d1cacefd1 100644 --- a/sharktank/sharktank/models/llama/llama.py +++ b/sharktank/sharktank/models/llama/llama.py @@ -186,29 +186,6 @@ def decode( self._assert_device(start_positions) self._assert_device(*cache_state, dtype=self.activation_dtype) - if self.config.tensor_parallelism_size > 1: - if not isinstance(tokens, ReplicatedTensor): - tokens = ops.replicate( - tokens, count=self.config.tensor_parallelism_size - ) - if not isinstance(attention_mask, ReplicatedTensor): - attention_mask = ops.replicate( - attention_mask, count=self.config.tensor_parallelism_size - ) - if not isinstance(start_positions, ReplicatedTensor): - start_positions = ops.replicate( - start_positions, count=self.config.tensor_parallelism_size - ) - if not isinstance(seq_block_ids, ReplicatedTensor): - seq_block_ids = ops.replicate( - seq_block_ids, count=self.config.tensor_parallelism_size - ) - # If the user provided unsharded arguments they probably want - # an unsharded result as well. - unshard_result = True - else: - unshard_result = False - bs, _ = tokens.shape # Precompute a position based mask for computing rope embeddings # as it is the same for all blocks. diff --git a/sharktank/sharktank/utils/testing.py b/sharktank/sharktank/utils/testing.py index 7b91b3a13..feebb25cd 100644 --- a/sharktank/sharktank/utils/testing.py +++ b/sharktank/sharktank/utils/testing.py @@ -14,9 +14,13 @@ from typing import Any, Callable from operator import eq from collections.abc import Iterable +import pytest +from sharktank.utils.tokenizer import InferenceTokenizer from ..types import * +longrun = pytest.mark.skipif("not config.getoption('longrun')") + # Range of torch.rand() is [0,1) # Range of torch.rand() * 2 - 1 is [-1, 1), includes negative values def make_rand_torch(shape, dtype=torch.float32): @@ -31,6 +35,16 @@ def tearDown(self): shutil.rmtree(self._temp_dir, ignore_errors=True) +@pytest.mark.usefixtures("path_prefix") +class PathPrefixTestBase(TempDirTestBase): + """Creates a temporary directory and uses it if a path prefix is not given.""" + + def setUp(self): + super().setUp() + if self.path_prefix is None: + self.path_prefix = f"{self._temp_dir}/" + + class MainRunnerTestBase(TempDirTestBase): """Performs an in-process test of a `main(args)` func.""" @@ -54,6 +68,25 @@ def assertFileWritten(self, p: Path): self.assertGreater(p.stat().st_size, 0, msg=f"Expected file {p} had zero size") +class ModuloTokenizer(InferenceTokenizer): + """A tokenizer used for testing where we take a modulo of each character. + Guarantees that we are producing tokens of up to the max token ID.""" + + def __init__(self, vocabulary_size: int): + self.vocabulary_size = vocabulary_size + + def _encode(self, texts: list[str], add_start_token: bool) -> list[list[int]]: + return [ + [ord(character) % self.vocabulary_size for character in text] + for text in texts + ] + + def _decode(self, tokens: list[list[int]]) -> list[str]: + return [ + "".join([chr(token) for token in prompt_tokens]) for prompt_tokens in tokens + ] + + @contextlib.contextmanager def temporary_directory(identifier: str): """Returns a context manager TemporaryDirectory suitable for testing. diff --git a/sharktank/sharktank/utils/tokenizer.py b/sharktank/sharktank/utils/tokenizer.py index b459c706a..597533373 100644 --- a/sharktank/sharktank/utils/tokenizer.py +++ b/sharktank/sharktank/utils/tokenizer.py @@ -75,7 +75,7 @@ def pad_tokens( return token_ids, lengths @abstractmethod - def _encode(self, texts: list[str]) -> list[list[int]]: + def _encode(self, texts: list[str], add_start_token: bool) -> list[list[int]]: ... @abstractmethod diff --git a/sharktank/tests/models/llama/sharded_llama_test.py b/sharktank/tests/models/llama/sharded_llama_test.py index 386061731..ee50d3d3a 100644 --- a/sharktank/tests/models/llama/sharded_llama_test.py +++ b/sharktank/tests/models/llama/sharded_llama_test.py @@ -4,42 +4,282 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -import unittest +from copy import deepcopy +from iree.compiler import compile_file, InputType +from typing import Any +import functools +import os import pytest -from typing import Any, List, Tuple, OrderedDict +import torch + +from sharktank.examples import export_paged_llm_v1 +from sharktank.examples.sharding import shard_llm_dataset +from sharktank.examples.paged_llm_v1 import TorchGenerator from sharktank.models.llama.llama import LlamaModelConfig, PagedLlamaModelV1 -import sharktank.ops as ops -from sharktank.types import unbox_tensor, Dataset, UnreducedTensor, SplitPrimitiveTensor -from sharktank.models.llama.testing import make_random_llama_theta -from sharktank.utils.testing import skip -from sharktank.models.llama.sharding import shard_theta +from sharktank.layers import CausalLMModelABC from sharktank.layers.configs import LlamaHParams +from sharktank.layers.testing import CausalLMIreeModel +from sharktank.models.llama.sharding import shard_theta +from sharktank.models.llama.testing import make_random_llama_theta +from sharktank.types import ( + AnyTensor, + InferenceTensor, + DefaultPrimitiveTensor, + Dataset, + dtype_to_serialized_name, +) from sharktank.utils.math import round_up_to_multiple_of -from sharktank.utils import iterables_equal from sharktank.utils.iree import ( get_iree_devices, load_iree_module, - run_iree_module_function, - prepare_iree_module_function_args, - call_torch_module_function, - iree_to_torch, ) -from sharktank.export import export as sharktank_export -import tempfile -import torch -from copy import deepcopy -from iree.turbine.aot import FxProgramsBuilder, export -import iree.runtime -import numpy as np -import os +from sharktank.utils.testing import PathPrefixTestBase, ModuloTokenizer, longrun +from sharktank.utils.tokenizer import load_tokenizer, InferenceTokenizer +import sharktank.ops as ops + + +AnyTokenizer = Any + + +def set_float_dtype(tensor: InferenceTensor, dtype: torch.dtype) -> InferenceTensor: + if isinstance(tensor, DefaultPrimitiveTensor) and tensor.dtype.is_floating_point: + return DefaultPrimitiveTensor( + name=tensor.name, data=ops.to(tensor, dtype=dtype) + ) + assert False, "Unsupported tensor type" + + +def shard_dataset( + path: str, + output_path: str, + tensor_parallelism_size: int, + intermediates_caching: bool, +): + if not intermediates_caching or not os.path.exists(output_path): + if path.endswith(".gguf"): + dataset_arg = f"--gguf-file={path}" + elif path.endswith(".irpa"): + dataset_arg = f"--irpa-file={path}" + else: + raise ValueError(f'Invalid dataset filename "{dataset_arg}"') + shard_llm_dataset.main( + [ + f"--tensor-parallelism-size={tensor_parallelism_size}", + dataset_arg, + f"--output-irpa-file={output_path}", + ] + ) + + +def compile_iree_module( + intermediates_caching: bool, + config: LlamaModelConfig, + dataset_path: str, + batch_size: int, + target_device: str, + output_mlir_path: str, + output_module_path: str, + output_config_path: str, +): + if not intermediates_caching or not os.path.exists(output_module_path): + export_paged_llm_v1.main( + [ + f"--output-mlir={output_mlir_path}", + f"--irpa-file={dataset_path}", + f"--output-config={output_config_path}", + f"--bs={batch_size}", + f"--block-seq-stride={config.block_seq_stride}", + f"--attention-dtype={dtype_to_serialized_name(config.attention_dtype)}", + f"--activation-dtype={dtype_to_serialized_name(config.activation_dtype)}", + ] + ) + compiler_extra_args = [ + f"--iree-hal-target-device={target_device}[{i}]" + for i in range(config.tensor_parallelism_size) + ] + + compile_file( + output_mlir_path, + input_type=InputType.TORCH, + output_file=output_module_path, + extra_args=compiler_extra_args, + ) -@pytest.mark.usefixtures("caching", "path_prefix") -class ShardedLlamaTest(unittest.TestCase): +def assert_close_cache_state( + actual: list[torch.Tensor], + expected: list[torch.Tensor], +): + torch.testing.assert_close( + actual[0].to(dtype=expected[0].dtype), expected[0], atol=1e-3, rtol=0 + ) + + +def assert_close_logits( + actual: torch.Tensor, + expected: torch.Tensor, +): + actual_probabilities = torch.softmax(actual, dim=1) + expected_probabilities = torch.softmax(expected, dim=1) + torch.testing.assert_close( + actual_probabilities.to(dtype=expected_probabilities.dtype), + expected_probabilities, + atol=1e-3, + rtol=0, + ) + + +def raise_multiple(errors): + if not errors: # list emptied, recursion ends + return + try: + raise errors.pop() # pop removes list entries + finally: + raise_multiple(errors) # recursion + + +def assert_close_post_call( + actual_logits: torch.Tensor, + expected_logits: torch.Tensor, + actual_cache_state: list[AnyTensor], + expected_cache_state: list[AnyTensor], +): + errors = [] + try: + assert_close_logits(actual_logits, expected_logits) + except Exception as ex: + errors.append(ex) + try: + assert_close_cache_state(actual_cache_state, expected_cache_state) + except Exception as ex: + errors.append(ex) + raise_multiple(errors) + + +def compare_models( + target_model: CausalLMModelABC, + reference_model: CausalLMModelABC, + tokenizer: InferenceTokenizer, + cache_page_count: int, + prompts: list[str], +): + generator = TorchGenerator( + target_model, tokenizer, page_cache_size=cache_page_count + ) + reference_generator = TorchGenerator( + reference_model, tokenizer, page_cache_size=cache_page_count + ) + batch = generator.begin_batch(prompts) + reference_batch = reference_generator.begin_batch(prompts) + + # Init the cache and copy it to both the target and the reference. + unsharded_reference_cache_state = reference_model.cache.paged.unshard_state( + reference_batch.cache_state + ) + torch.full( + size=unsharded_reference_cache_state[0].shape, + fill_value=0, + out=unsharded_reference_cache_state[0], + ) + reference_batch.cache_state[0][...] = reference_model.cache.paged.shard_state( + unsharded_reference_cache_state + )[0] + batch.cache_state[0][...] = target_model.cache.paged.shard_state( + unsharded_reference_cache_state + )[0] + + batch.prefill() + reference_batch.prefill() + assert_close_post_call( + actual_logits=batch.logits, + expected_logits=reference_batch.logits, + actual_cache_state=target_model.cache.paged.unshard_state(batch.cache_state), + expected_cache_state=reference_batch.cache_state, + ) + + batch.decode() + reference_batch.decode() + assert_close_post_call( + actual_logits=batch.logits, + expected_logits=reference_batch.logits, + actual_cache_state=target_model.cache.paged.unshard_state(batch.cache_state), + expected_cache_state=reference_batch.cache_state, + ) + + +def run_test_compare_iree_against_torch( + path_prefix: str, + intermediates_caching: bool, + torch_dataset_path: str, + torch_config: LlamaModelConfig, + iree_dataset_path: str, + iree_config: LlamaModelConfig, + iree_target_device: str, + iree_driver: str, + tokenizer: InferenceTokenizer, + prompts: list[str], + cache_page_count: int, +): + iree_module_path = f"{path_prefix}program.vmfb" + compile_iree_module( + intermediates_caching=intermediates_caching, + config=iree_config, + dataset_path=iree_dataset_path, + batch_size=len(prompts), + target_device=iree_target_device, + output_mlir_path=f"{path_prefix}program.mlir", + output_module_path=iree_module_path, + output_config_path=f"{path_prefix}program_config.json", + ) + iree_devices = get_iree_devices( + driver=iree_driver, + device_count=iree_config.tensor_parallelism_size, + ) + iree_module, vm_context, vm_instance = load_iree_module( + module_path=iree_module_path, + devices=iree_devices, + parameters_path=iree_dataset_path, + ) + iree_model = CausalLMIreeModel( + batch_size=len(prompts), + config=iree_config, + vm_context=vm_context, + iree_driver=iree_driver, + iree_module=iree_module, + iree_devices=iree_devices, + ) + + torch_dataset = Dataset.load(torch_dataset_path, mmap=False) + torch_model = PagedLlamaModelV1(theta=torch_dataset.root_theta, config=torch_config) + + compare_models( + target_model=iree_model, + reference_model=torch_model, + tokenizer=tokenizer, + cache_page_count=cache_page_count, + prompts=prompts, + ) + + +@pytest.mark.usefixtures("caching") +class ShardedLlamaTestBase(PathPrefixTestBase): def setUp(self): + super().setUp() torch.random.manual_seed(123456) - self.dtype = torch.float32 - torch.set_default_dtype(self.dtype) + self.intermediates_caching = self.caching + self.prompts = [ + "The sky is blue", + "The night is dark", + "Linguistics is the study of", + ] + + +class ShardedLlamaToySizedTest(ShardedLlamaTestBase): + def setUp(self): + super().setUp() + self.reference_dtype = torch.float64 + self.target_dtype = torch.float32 + torch.set_default_dtype(self.reference_dtype) self.batch_size = 3 self.attention_head_count_kv = 4 self.attention_head_count = self.attention_head_count_kv * 5 @@ -47,10 +287,16 @@ def setUp(self): self.rope_dimension_count = 7 * 2 self.attn_head_dim = self.rope_dimension_count self.block_seq_stride = 13 + self.context_length = round_up_to_multiple_of( + functools.reduce(max, [len(prompt) for prompt in self.prompts]), + self.block_seq_stride, + ) + # Make this large enough to make torch.export.Dim happy. + self.context_length = max(self.context_length, 4 * self.block_seq_stride) self.cache_page_count = 11 self.config = LlamaModelConfig( hp=LlamaHParams( - context_length=self.block_seq_stride * 2, + context_length=self.context_length, embedding_length=self.attention_head_count * self.attn_head_dim, block_count=3, feed_forward_length=23, @@ -65,342 +311,157 @@ def setUp(self): model_arch="llama", ), block_seq_stride=self.block_seq_stride, - activation_dtype=self.dtype, - attention_dtype=self.dtype, + activation_dtype=self.reference_dtype, + attention_dtype=self.reference_dtype, + static_tables=False, ) self.sharded_config = deepcopy(self.config) self.sharded_config.tensor_parallelism_size = 2 + self.sharded_config.activation_dtype = self.target_dtype + self.sharded_config.attention_dtype = self.target_dtype + self.theta = make_random_llama_theta( config=self.config, vocab_size=self.vocabulary_size, ) - self.prefill_seq_lens = torch.tensor( - [14, 9, self.block_seq_stride - 1], dtype=torch.int64 - ) + self.theta.rename_tensors_to_paths() - def make_prefill_args(self, model: PagedLlamaModelV1) -> OrderedDict[str, Any]: - batch_seq_len = round_up_to_multiple_of( - int(torch.max(self.prefill_seq_lens)), model.cache.pad_sequence_stride - ) - token_ids = torch.randint( - low=0, - high=self.vocabulary_size, - size=[self.batch_size, batch_seq_len], - dtype=torch.int32, - ) - attention_mask = model.attention_mask( - model.input_mask(self.prefill_seq_lens, batch_seq_len) - ) - seq_block_ids = torch.arange( - self.batch_size * batch_seq_len // self.config.block_seq_stride - ).view(self.batch_size, -1) - cache_state = model.cache.paged.allocate(page_count=self.cache_page_count) - cache_state = [torch.rand_like(cache_state[0])] - return OrderedDict( - [ - ("tokens", token_ids), - ("attention_mask", attention_mask), - ("seq_block_ids", seq_block_ids), - ("cache_state", cache_state), - ] - ) + self.tokenizer = ModuloTokenizer(self.vocabulary_size) - def make_equal_unsharded_and_sharded_prefill_args( - self, model: PagedLlamaModelV1, sharded_model: PagedLlamaModelV1 - ) -> Tuple[OrderedDict[str, Any], OrderedDict[str, Any]]: - prefill_kwargs = self.make_prefill_args(model) - sharded_cache_state = sharded_model.cache.paged.allocate( - page_count=self.cache_page_count - ) - assert iterables_equal( - prefill_kwargs["cache_state"][0].shape, sharded_cache_state[0].shape - ) - sharded_prefill_kwargs = deepcopy(prefill_kwargs) - sharded_cache_state = sharded_model.cache.paged.shard_state( - sharded_prefill_kwargs["cache_state"] - ) - sharded_prefill_kwargs["cache_state"] = sharded_cache_state - - sharding = sharded_model.config.tensor_parallelism_size - for k in sharded_prefill_kwargs: - if k == "cache_state": - continue - sharded_prefill_kwargs[k] = ops.replicate( - sharded_prefill_kwargs[k], count=sharding - ) - - return prefill_kwargs, sharded_prefill_kwargs - - def make_decode_args(self, model: PagedLlamaModelV1) -> OrderedDict[str, Any]: - start_positions = self.prefill_seq_lens.clone() - seq_lens = self.prefill_seq_lens + 1 - batch_seq_len = round_up_to_multiple_of( - int(torch.max(seq_lens)), model.cache.pad_sequence_stride + def testCompareTensorParallelToUnsharded(self): + """Run a sharded variant of a toy model size and compare it against the + unsharded variant.""" + sharded_theta = self.theta.transform( + functools.partial(set_float_dtype, dtype=self.target_dtype) ) - decode_token_ids = torch.randint( - low=0, - high=self.vocabulary_size, - size=[self.batch_size, 1], - dtype=torch.int32, + sharded_theta = shard_theta(sharded_theta, self.sharded_config) + sharded_model = PagedLlamaModelV1(sharded_theta, self.sharded_config) + reference_model = PagedLlamaModelV1(self.theta, self.config) + compare_models( + target_model=sharded_model, + reference_model=reference_model, + tokenizer=self.tokenizer, + prompts=self.prompts, + cache_page_count=self.cache_page_count, ) - attention_mask = model.decode_attention_mask( - model.input_mask(seq_lens, batch_seq_len) + + def testCompareTensorParallelWithIreeToUnsharded(self): + """Test exporting to MLIR and compiling with IREE the sharded Llama model. + Test numerical accuracy of the IREE module against PyTorch.""" + + dataset = Dataset( + properties=self.config.hp.to_gguf_props(), root_theta=self.theta ) - seq_block_ids = torch.arange( - self.batch_size * batch_seq_len // self.config.block_seq_stride - ).view(self.batch_size, -1) - cache_state = model.cache.paged.allocate(page_count=self.cache_page_count) - cache_state = [torch.rand_like(cache_state[0])] - return OrderedDict( - [ - ("tokens", decode_token_ids), - ("attention_mask", attention_mask), - ("start_positions", start_positions), - ("seq_block_ids", seq_block_ids), - ("cache_state", cache_state), - ] + torch_dataset_path = f"{self.path_prefix}torch-reference-dataset.irpa" + if not self.intermediates_caching or not os.path.exists(torch_dataset_path): + dataset.save(torch_dataset_path) + + iree_unsharded_theta = self.theta.transform( + functools.partial(set_float_dtype, dtype=self.target_dtype) + ) + iree_unsharded_dataset = Dataset( + properties=self.sharded_config.hp.to_gguf_props(), + root_theta=iree_unsharded_theta, + ) + iree_usharded_dataset_path = f"{self.path_prefix}iree-dataset-unsharded.irpa" + if not self.intermediates_caching or not os.path.exists( + iree_usharded_dataset_path + ): + iree_unsharded_dataset.save(iree_usharded_dataset_path) + + iree_dataset_path = f"{self.path_prefix}iree-dataset.irpa" + + shard_dataset( + path=iree_usharded_dataset_path, + output_path=iree_dataset_path, + tensor_parallelism_size=self.sharded_config.tensor_parallelism_size, + intermediates_caching=self.intermediates_caching, ) - def make_equal_unsharded_and_sharded_decode_args( - self, model: PagedLlamaModelV1, sharded_model: PagedLlamaModelV1 - ) -> Tuple[OrderedDict[str, Any], OrderedDict[str, Any]]: - decode_kwargs = self.make_decode_args(model) - sharded_decode_kwargs = deepcopy(decode_kwargs) - sharded_decode_kwargs["cache_state"] = sharded_model.cache.paged.shard_state( - sharded_decode_kwargs["cache_state"] + run_test_compare_iree_against_torch( + path_prefix=self.path_prefix, + intermediates_caching=self.intermediates_caching, + torch_dataset_path=torch_dataset_path, + torch_config=self.config, + iree_dataset_path=iree_dataset_path, + iree_config=self.sharded_config, + iree_target_device="llvm-cpu", + iree_driver="local-task", + tokenizer=self.tokenizer, + prompts=self.prompts, + cache_page_count=self.cache_page_count, ) - sharding = sharded_model.config.tensor_parallelism_size - for k in sharded_decode_kwargs: - if k == "cache_state": - continue - sharded_decode_kwargs[k] = ops.replicate( - sharded_decode_kwargs[k], count=sharding - ) - return decode_kwargs, sharded_decode_kwargs +@pytest.mark.usefixtures("get_model_path") +class Llama38BFp16Tp8Test(ShardedLlamaTestBase): + def setUp(self): + super().setUp() + tokenizer_path = self.llama3_8b_tokenizer + self.tokenizer = load_tokenizer(tokenizer_path.parent) - def testCompareToySizedModelToUnsharded(self): - """Run a sharded variant of a toy model size and compare it against the - unsharded variant.""" - model = PagedLlamaModelV1(self.theta, self.config) - sharded_theta = shard_theta(self.theta, self.sharded_config) - sharded_model = PagedLlamaModelV1(sharded_theta, self.sharded_config) + self.reference_dtype = torch.float64 + self.dataset_path = str(self.llama3_8b_f16_model) + self.batch_size = 4 + self.cache_page_count = 8192 + tensor_parallelism_size = 8 - # Verify prefill step. - ( - prefill_kwargs, - sharded_prefill_kwargs, - ) = self.make_equal_unsharded_and_sharded_prefill_args(model, sharded_model) - - expected_prefill_result = model.prefill(**prefill_kwargs) - sharded_prefill_result = sharded_model.prefill(**sharded_prefill_kwargs) - sharded_prefill_result = ops.unshard(sharded_prefill_result) - # The errors are quite high, but for float64 both errors drop to < 1e-12. - # The numerics are probably correct. - torch.testing.assert_close( - sharded_prefill_result, expected_prefill_result, atol=1e-3, rtol=1e-2 - ) - expected_cache_state = prefill_kwargs["cache_state"][0] - actual_cache_state = ops.unshard( - sharded_model.cache.paged.unflatten_page_table( - sharded_prefill_kwargs["cache_state"] - ) - ).flatten(start_dim=1) - torch.testing.assert_close( - actual_cache_state, expected_cache_state, atol=1e-4, rtol=1e-1 - ) + dataset = Dataset.load(self.dataset_path) + self.theta = dataset.root_theta - # Verify decode step. - ( - decode_kwargs, - sharded_decode_kwargs, - ) = self.make_equal_unsharded_and_sharded_decode_args(model, sharded_model) - expected_decode_result = model.decode(**decode_kwargs) - sharded_decode_result = sharded_model.decode(**sharded_decode_kwargs) - sharded_decode_result = ops.unshard(sharded_decode_result) - torch.testing.assert_close( - sharded_decode_result, expected_decode_result, atol=1e-4, rtol=1e-5 + self.config = LlamaModelConfig( + hp=LlamaHParams.from_gguf_props(dataset.properties), + activation_dtype=self.reference_dtype, + attention_dtype=self.reference_dtype, + static_tables=False, ) - expected_decode_cache_state = decode_kwargs["cache_state"][0] - actual_decode_cache_state = ops.unshard( - sharded_model.cache.paged.unflatten_page_table( - sharded_decode_kwargs["cache_state"] - ) - ).flatten(start_dim=1) - # TODO: investigate why the Windows machine CI is producing a larger numerical - # error. - # The Ubuntu CI runs fine with default tolerances. - torch.testing.assert_close( - actual_decode_cache_state, expected_decode_cache_state, atol=1e-4, rtol=1e-4 + self.sharded_config = LlamaModelConfig( + hp=LlamaHParams.from_gguf_props(dataset.properties), + tensor_parallelism_size=tensor_parallelism_size, + static_tables=False, # Rely on the compiler for hoisting tables. ) - @skip( - ( - "Before this does not crash at all we need " - "https://github.com/iree-org/iree/pull/18663 merged." - ) + def tearDown(self): + # make sure we don't reference the memory mapped file. + del self.theta + super().tearDown() + + @longrun + @pytest.mark.xfail( + reason="Numerics are not close.", raises=AssertionError, strict=True ) - def testExportAndRunToySizedModelWithIree(self): + def testCompareTensorParallelWithIreeToUnsharded(self): """Test exporting to MLIR and compiling with IREE the sharded Llama model. Test numerical accuracy of the IREE module against PyTorch.""" - if self.path_prefix is not None: - self.runTestExportAndRunToySizedModelWithIree( - path_prefix=self.path_prefix, dump_enabled=True - ) - else: - with tempfile.TemporaryDirectory() as temp_dir: - self.runTestExportAndRunToySizedModelWithIree( - path_prefix=f"{temp_dir}/", dump_enabled=False - ) - - def runTestExportAndRunToySizedModelWithIree( - self, path_prefix: str, dump_enabled: bool - ): - sharded_theta = shard_theta(self.theta, self.sharded_config) - sharded_theta.rename_tensors_to_paths() - sharded_dataset = Dataset({}, sharded_theta) - sharded_parameters_path = f"{path_prefix}parameters.irpa" - sharded_dataset.save(sharded_parameters_path) - sharded_dataset = Dataset.load(sharded_parameters_path, mmap=False) - iree_driver = "local-task" - - model = PagedLlamaModelV1(self.theta, self.config) - sharded_model = PagedLlamaModelV1( - sharded_dataset.root_theta, self.sharded_config - ) - ( - _, - sharded_prefill_kwargs, - ) = self.make_equal_unsharded_and_sharded_prefill_args(model, sharded_model) - ( - _, - sharded_decode_kwargs, - ) = self.make_equal_unsharded_and_sharded_decode_args(model, sharded_model) - - iree_module_path = f"{path_prefix}program.vmfb" - if not self.caching or not os.path.exists(iree_module_path): - # Export and compile the IREE module. - sharded_fxb = FxProgramsBuilder(sharded_model) - - @sharktank_export( - fx_builder=sharded_fxb, - name="prefill", - kwargs=sharded_prefill_kwargs, - strict=False, - ) - def _(model, *args, **kwargs) -> torch.Tensor: - return model.prefill(*args, **kwargs) - - # TODO: remove strict=False when - # https://github.com/pytorch/pytorch/issues/136757 - # is resolved. - @sharktank_export( - fx_builder=sharded_fxb, - name="decode", - kwargs=sharded_decode_kwargs, - strict=False, - ) - def _(model, *args, **kwargs) -> torch.Tensor: - return model.decode(*args, **kwargs) - - output = export(sharded_fxb) - if dump_enabled: - output.save_mlir(f"{path_prefix}program.mlir") - output.session.set_flags( - *[ - f"--iree-hal-target-device=llvm-cpu[{i}]" - for i in range(self.sharded_config.tensor_parallelism_size) - ] - ) - output.compile( - save_to=iree_module_path, - target_backends=None, - ) - - iree_devices = get_iree_devices( - driver=iree_driver, - device_count=self.sharded_config.tensor_parallelism_size, - ) - iree_module, vm_context, vm_instance = load_iree_module( - module_path=iree_module_path, - devices=iree_devices, - parameters_path=sharded_parameters_path, - ) - - # Run prefill step. - prefill_iree_args = prepare_iree_module_function_args( - args=deepcopy(sharded_prefill_kwargs).values(), devices=iree_devices + reference_theta = self.theta.transform( + functools.partial(set_float_dtype, dtype=self.reference_dtype) ) - for i, arg in enumerate(prefill_iree_args): - np.save(f"{path_prefix}prefill_arg{i}.npy", arg.to_host()) - prefill_iree_result = run_iree_module_function( - args=prefill_iree_args, - function_name="prefill", - module=iree_module, - vm_context=vm_context, - driver=iree_driver, - trace_path_prefix=path_prefix if dump_enabled else None, - ) - prefill_iree_result = UnreducedTensor(ts=iree_to_torch(*prefill_iree_result)) - expected_prefill_result = call_torch_module_function( - module=sharded_model, - function_name="prefill", - kwargs=sharded_prefill_kwargs, - trace_path_prefix=f"{path_prefix}expected_" if dump_enabled else None, - ) - prefill_iree_cache_state_shards = prefill_iree_args[ - -self.config.tensor_parallelism_size - 1 : - ] - prefill_iree_cache_state = SplitPrimitiveTensor( - ts=iree_to_torch(*prefill_iree_cache_state_shards), - shard_dim=sharded_prefill_kwargs["cache_state"][0].shard_dim, + reference_dataset = Dataset( + properties=self.config.hp.to_gguf_props(), root_theta=reference_theta ) + reference_dataset_path = f"{self.path_prefix}torch-reference-dataset.irpa" + if not self.intermediates_caching or not os.path.exists(reference_dataset_path): + reference_dataset.save(reference_dataset_path) + target_dataset_path = f"{self.path_prefix}iree-dataset.irpa" - # Run decode step. - decode_iree_args = prepare_iree_module_function_args( - args=deepcopy(sharded_decode_kwargs).values(), devices=iree_devices - ) - decode_iree_result = run_iree_module_function( - args=decode_iree_args, - function_name="decode", - module=iree_module, - vm_context=vm_context, - driver=iree_driver, - trace_path_prefix=path_prefix if dump_enabled else None, - ) - decode_iree_result = UnreducedTensor(ts=iree_to_torch(*decode_iree_result)) - expected_decode_result = call_torch_module_function( - module=sharded_model, - function_name="decode", - kwargs=sharded_decode_kwargs, - trace_path_prefix=f"{path_prefix}expected_" if dump_enabled else None, - ) - decode_iree_cache_state_shards = decode_iree_args[ - -self.config.tensor_parallelism_size - 1 : - ] - decode_iree_cache_state = SplitPrimitiveTensor( - ts=iree_to_torch(*decode_iree_cache_state_shards), - shard_dim=sharded_decode_kwargs["cache_state"][0].shard_dim, + shard_dataset( + path=self.dataset_path, + output_path=target_dataset_path, + tensor_parallelism_size=self.sharded_config.tensor_parallelism_size, + intermediates_caching=self.intermediates_caching, ) - # Check IREE's numerical correctness against PyTorch. - # TODO: Although, not entirely wrong, investigate why this accuracy is that - # low for fp32 (atol=0.0011, rtol=0.013). - torch.testing.assert_close( - ops.unshard(prefill_iree_result), - ops.unshard(expected_prefill_result), - ) - torch.testing.assert_close( - ops.unshard(prefill_iree_cache_state), - ops.unshard(sharded_prefill_kwargs["cache_state"][0]), - ) - torch.testing.assert_close( - ops.unshard(decode_iree_result), - ops.unshard(expected_decode_result), - ) - torch.testing.assert_close( - ops.unshard(decode_iree_cache_state), - ops.unshard(sharded_decode_kwargs["cache_state"][0]), + run_test_compare_iree_against_torch( + path_prefix=self.path_prefix, + intermediates_caching=self.intermediates_caching, + torch_dataset_path=self.dataset_path, + torch_config=self.config, + iree_dataset_path=target_dataset_path, + iree_config=self.sharded_config, + iree_target_device="llvm-cpu", + iree_driver="local-task", + tokenizer=self.tokenizer, + prompts=self.prompts, + cache_page_count=self.cache_page_count, )