From 9ca19a524f7c86ecfd2202e1544bc1b9deab4a9c Mon Sep 17 00:00:00 2001 From: saienduri <77521230+saienduri@users.noreply.github.com> Date: Mon, 7 Oct 2024 14:11:25 -0700 Subject: [PATCH 01/15] CI/repo changes for shark-turbine to iree-turbine rename (#260) This commit makes the necessary changes to adapt to the new iree-turbine namespace. Related PR: https://github.com/iree-org/iree-turbine/pull/197 Signed-off-by: saienduri --- .github/workflows/ci.yaml | 2 +- .github/workflows/test.yaml | 4 ++-- README.md | 2 +- docs/model_cookbook.md | 2 +- docs/quantization.md | 6 +++--- sharktank/setup.py | 2 +- sharktank/sharktank/examples/export_paged_llm_v1.py | 2 +- sharktank/sharktank/examples/sharding/export_ffn_net.py | 2 +- sharktank/sharktank/examples/sharding/export_gemm.py | 2 +- sharktank/sharktank/export_layer/export_moe.py | 2 +- sharktank/sharktank/export_layer/export_paged_attention.py | 2 +- sharktank/sharktank/kernels/base.py | 6 +++--- sharktank/sharktank/models/punet/tools/run_punet.py | 2 +- sharktank/sharktank/ops/default_impls.py | 4 ++-- sharktank/sharktank/types/gguf_interop/base.py | 2 +- sharktank/sharktank/types/tensors.py | 2 +- sharktank/sharktank/types/theta.py | 2 +- sharktank/sharktank/utils/io.py | 2 +- sharktank/sharktank/utils/logging.py | 2 +- sharktank/tests/kernels/batch_matmul_transpose_b_test.py | 2 +- sharktank/tests/kernels/conv_2d_nchw_fchw_test.py | 2 +- sharktank/tests/kernels/einsum_q4_test.py | 2 +- sharktank/tests/kernels/mmt_block_scaled_offset_q4_test.py | 2 +- sharktank/tests/kernels/mmt_block_scaled_q8_test.py | 2 +- .../tests/kernels/mmt_super_block_scaled_offset_q4_test.py | 2 +- sharktank/tests/kernels/mmtfp_test.py | 2 +- sharktank/tests/kernels/pooling_nchw_sum_test.py | 2 +- sharktank/tests/layers/sharded_conv2d_with_iree_test.py | 2 +- sharktank/tests/models/llama/moe_block_test.py | 2 +- sharktank/tests/models/llama/sharded_llama_test.py | 2 +- .../models/punet/sharded_resnet_block_with_iree_test.py | 2 +- sharktank/tests/types/dataset_test.py | 2 +- turbine-requirements.txt | 2 +- 33 files changed, 39 insertions(+), 39 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 2cf75d06a..90aa8220c 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -54,7 +54,7 @@ jobs: # wheels saves multiple minutes and a lot of bandwidth on runner setup. pip install --no-compile -r pytorch-cpu-requirements.txt pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \ - -e "git+https://github.com/iree-org/iree-turbine.git#egg=shark-turbine" + -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" pip install --no-compile -r requirements.txt -e sharktank/ - name: Run sharktank tests diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 6eb519717..8b3f50944 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -53,7 +53,7 @@ jobs: # wheels saves multiple minutes and a lot of bandwidth on runner setup. pip install --no-compile -r pytorch-cpu-requirements.txt pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \ - -e "git+https://github.com/iree-org/iree-turbine.git#egg=shark-turbine" + -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" pip install --no-compile -r requirements.txt -e sharktank/ shortfin/ # Try with the latest nightly releases, not what iree-turbine pins. @@ -85,7 +85,7 @@ jobs: python -m pip install --no-compile --upgrade pip pip install --no-compile -r pytorch-rocm-requirements.txt pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \ - -e "git+https://github.com/iree-org/iree-turbine.git#egg=shark-turbine" + -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" pip install --no-compile -r requirements.txt -e sharktank/ shortfin/ - name: Run punet tests diff --git a/README.md b/README.md index 7ff8d0126..f8d001c88 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ pip install -r pytorch-rocm-requirements.txt ``` # Clone and install editable iree-turbine dep in deps/ pip install -f https://iree.dev/pip-release-links.html --src deps \ - -e "git+https://github.com/iree-org/iree-turbine.git#egg=shark-turbine" + -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" # Install editable local projects. pip install -r requirements.txt -e sharktank/ shortfin/ diff --git a/docs/model_cookbook.md b/docs/model_cookbook.md index 64137956d..becf40820 100644 --- a/docs/model_cookbook.md +++ b/docs/model_cookbook.md @@ -176,7 +176,7 @@ source .venv/bin/activate # Install requirements. pip install -r pytorch-cpu-requirements.txt pip install -f https://iree.dev/pip-release-links.html --src deps \ - -e "git+https://github.com/iree-org/iree-turbine.git#egg=shark-turbine" + -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" # Install local projects. pip install -r requirements.txt -e sharktank/ shortfin/ diff --git a/docs/quantization.md b/docs/quantization.md index 0563e8108..51b26b3b5 100644 --- a/docs/quantization.md +++ b/docs/quantization.md @@ -277,9 +277,9 @@ is everything). We're just starting to exploit some of this as the PyTorch level. Some examples: * Something as simple as a humble runtime -[tensor trace/print](https://github.com/iree-org/iree-turbine/blob/main/shark_turbine/ops/iree.py#L52) -* [Simple linalg based template expansion](https://github.com/iree-org/iree-turbine/blob/main/shark_turbine/ops/_jinja_test_ops.py#L28) - (see backing example [jinja template](https://github.com/iree-org/iree-turbine/blob/main/shark_turbine/ops/templates/test_add_jinja.mlir)). +[tensor trace/print](https://github.com/iree-org/iree-turbine/blob/main/iree.turbine/ops/iree.py#L52) +* [Simple linalg based template expansion](https://github.com/iree-org/iree-turbine/blob/main/iree.turbine/ops/_jinja_test_ops.py#L28) + (see backing example [jinja template](https://github.com/iree-org/iree-turbine/blob/main/iree.turbine/ops/templates/test_add_jinja.mlir)). * Optimal linalg-based [8-bit block scaled mmt for weight compression](https://github.com/nod-ai/sharktank/blob/main/sharktank/sharktank/kernels/mmt_block_scaled_q8.py) (see backing [jinja template](https://github.com/nod-ai/sharktank/blob/main/sharktank/sharktank/kernels/templates/mmt_block_scaled_q8_3d.mlir)). * DSL based [like this fused attention kernel](https://github.com/iree-org/iree-turbine/blob/main/tests/kernel/fused_attention_test.py#L20) diff --git a/sharktank/setup.py b/sharktank/setup.py index b8caf9e7d..21be90019 100644 --- a/sharktank/setup.py +++ b/sharktank/setup.py @@ -95,7 +95,7 @@ def initialize_options(self): "sharktank": ["py.typed", "kernels/templates/*.mlir"], }, install_requires=[ - "shark-turbine", + "iree-turbine", ], extras_require={ "testing": [ diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py index 3e094b494..484d094e3 100644 --- a/sharktank/sharktank/examples/export_paged_llm_v1.py +++ b/sharktank/sharktank/examples/export_paged_llm_v1.py @@ -9,7 +9,7 @@ import json import torch -from shark_turbine.aot import * +from iree.turbine.aot import * from sharktank.layers import * from sharktank.types import * diff --git a/sharktank/sharktank/examples/sharding/export_ffn_net.py b/sharktank/sharktank/examples/sharding/export_ffn_net.py index f80b9a2ac..4885b7d54 100644 --- a/sharktank/sharktank/examples/sharding/export_ffn_net.py +++ b/sharktank/sharktank/examples/sharding/export_ffn_net.py @@ -89,7 +89,7 @@ def main(raw_args=None): ds = Dataset.load(args.output_irpa_file) mdl = ShardedFFN(ds.root_theta) - from shark_turbine import aot + from iree.turbine import aot example_arg = torch.empty(bs, sl, primary_dim, dtype=torch.float16) ep = torch.export.export(mdl, (example_arg,)) diff --git a/sharktank/sharktank/examples/sharding/export_gemm.py b/sharktank/sharktank/examples/sharding/export_gemm.py index 7a4322e38..9744a6d82 100644 --- a/sharktank/sharktank/examples/sharding/export_gemm.py +++ b/sharktank/sharktank/examples/sharding/export_gemm.py @@ -4,7 +4,7 @@ import torch from torch import Tensor from sharktank import ops -from shark_turbine import aot +from iree.turbine import aot def export_gemm( diff --git a/sharktank/sharktank/export_layer/export_moe.py b/sharktank/sharktank/export_layer/export_moe.py index e8f257bfe..f2c10c4b4 100644 --- a/sharktank/sharktank/export_layer/export_moe.py +++ b/sharktank/sharktank/export_layer/export_moe.py @@ -5,7 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import torch -from shark_turbine.aot import * +from iree.turbine.aot import * from sharktank.models.llama.testing import make_moe_block_theta, make_rand_torch from sharktank.layers.mixture_of_experts_block import PreGatherMoeBlock from ..utils import cli diff --git a/sharktank/sharktank/export_layer/export_paged_attention.py b/sharktank/sharktank/export_layer/export_paged_attention.py index aa0cdf961..9cea8bed4 100644 --- a/sharktank/sharktank/export_layer/export_paged_attention.py +++ b/sharktank/sharktank/export_layer/export_paged_attention.py @@ -13,7 +13,7 @@ import torch.nn.functional as F -from shark_turbine.aot import * +from iree.turbine.aot import * from sharktank.layers import * from sharktank.types import * diff --git a/sharktank/sharktank/kernels/base.py b/sharktank/sharktank/kernels/base.py index 8c99c81d9..ce792b525 100644 --- a/sharktank/sharktank/kernels/base.py +++ b/sharktank/sharktank/kernels/base.py @@ -12,7 +12,7 @@ from jinja2 import Environment, PackageLoader, select_autoescape -from shark_turbine.support.ir_imports import ( +from iree.turbine.support.ir_imports import ( FlatSymbolRefAttr, FunctionType, IrType, @@ -24,7 +24,7 @@ Value, ) -from shark_turbine.runtime.op_reg import ( +from iree.turbine.runtime.op_reg import ( def_library, CustomOp, KernelBuilder, @@ -32,7 +32,7 @@ TensorArg, ) -from shark_turbine.transforms.merger import Merger +from iree.turbine.transforms.merger import Merger from ..utils.logging import get_logger diff --git a/sharktank/sharktank/models/punet/tools/run_punet.py b/sharktank/sharktank/models/punet/tools/run_punet.py index ace279a3b..b2ad58d9d 100644 --- a/sharktank/sharktank/models/punet/tools/run_punet.py +++ b/sharktank/sharktank/models/punet/tools/run_punet.py @@ -9,7 +9,7 @@ import torch -from shark_turbine import aot +from iree.turbine import aot from ..model import Unet2DConditionModel, ClassifierFreeGuidanceUnetModel from ....utils.patching import SaveModuleResultTensorsPatch diff --git a/sharktank/sharktank/ops/default_impls.py b/sharktank/sharktank/ops/default_impls.py index ed6e6c730..fec30fca6 100644 --- a/sharktank/sharktank/ops/default_impls.py +++ b/sharktank/sharktank/ops/default_impls.py @@ -24,7 +24,7 @@ from ..types.tensors import unbox_tensor, AnyTensor from ._registry import AllOfType, AllOfExprs, AllOfExprsVariadic, IsOfType from .signatures import * -import shark_turbine.ops.iree +import iree.turbine.ops.iree @cat.override(AllOfType(Tensor, PrimitiveTensor)) @@ -393,7 +393,7 @@ def to_default(tensor: Tensor, *args, **kwargs): @transfer_to_logical_device.override(Tensor) def transfer_to_logical_device_default(tensor: Tensor, ordinal: int): - return shark_turbine.ops.iree.transfer_to_logical_device( + return iree.turbine.ops.iree.transfer_to_logical_device( f"{ordinal}", unbox_tensor(tensor) ) diff --git a/sharktank/sharktank/types/gguf_interop/base.py b/sharktank/sharktank/types/gguf_interop/base.py index ab383a14c..44674bc83 100644 --- a/sharktank/sharktank/types/gguf_interop/base.py +++ b/sharktank/sharktank/types/gguf_interop/base.py @@ -13,7 +13,7 @@ from gguf import GGUFReader, GGUFValueType -from shark_turbine.aot import ( +from iree.turbine.aot import ( ExternalTensorTrait, ) diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py index 93aac9e34..200800d44 100644 --- a/sharktank/sharktank/types/tensors.py +++ b/sharktank/sharktank/types/tensors.py @@ -27,7 +27,7 @@ from torch import Tensor from torch.utils._pytree import register_pytree_node, SequenceKey from ..utils.math import ceildiv -from shark_turbine.aot import ( +from iree.turbine.aot import ( ExternalTensorTrait, ) from ..utils import tree as tree_utils diff --git a/sharktank/sharktank/types/theta.py b/sharktank/sharktank/types/theta.py index 3537726ec..29bc29bb8 100644 --- a/sharktank/sharktank/types/theta.py +++ b/sharktank/sharktank/types/theta.py @@ -15,7 +15,7 @@ import torch import torch.nn.functional as F -from shark_turbine.aot import ( +from iree.turbine.aot import ( ExternalTensorTrait, ParameterArchive, ParameterArchiveEntry, diff --git a/sharktank/sharktank/utils/io.py b/sharktank/sharktank/utils/io.py index 62fd78f33..ac2480846 100644 --- a/sharktank/sharktank/utils/io.py +++ b/sharktank/sharktank/utils/io.py @@ -6,7 +6,7 @@ from pathlib import Path -from shark_turbine.aot import ( +from iree.turbine.aot import ( ParameterArchiveBuilder, ) diff --git a/sharktank/sharktank/utils/logging.py b/sharktank/sharktank/utils/logging.py index 977462d86..3801f96cb 100644 --- a/sharktank/sharktank/utils/logging.py +++ b/sharktank/sharktank/utils/logging.py @@ -6,7 +6,7 @@ import logging -from shark_turbine.support.logging import get_logger +from iree.turbine.support.logging import get_logger transform_logger: logging.Logger = get_logger("sharktank.transforms") diff --git a/sharktank/tests/kernels/batch_matmul_transpose_b_test.py b/sharktank/tests/kernels/batch_matmul_transpose_b_test.py index 30cc2296c..208d54782 100644 --- a/sharktank/tests/kernels/batch_matmul_transpose_b_test.py +++ b/sharktank/tests/kernels/batch_matmul_transpose_b_test.py @@ -13,7 +13,7 @@ import torch -from shark_turbine import aot +from iree.turbine import aot from sharktank import kernels diff --git a/sharktank/tests/kernels/conv_2d_nchw_fchw_test.py b/sharktank/tests/kernels/conv_2d_nchw_fchw_test.py index b03293523..637bf74c8 100644 --- a/sharktank/tests/kernels/conv_2d_nchw_fchw_test.py +++ b/sharktank/tests/kernels/conv_2d_nchw_fchw_test.py @@ -13,7 +13,7 @@ import torch -from shark_turbine import aot +from iree.turbine import aot from sharktank import kernels from sharktank.ops.qconv_impls import _pad_last_2d diff --git a/sharktank/tests/kernels/einsum_q4_test.py b/sharktank/tests/kernels/einsum_q4_test.py index d94ec5851..5f037ba9a 100644 --- a/sharktank/tests/kernels/einsum_q4_test.py +++ b/sharktank/tests/kernels/einsum_q4_test.py @@ -13,7 +13,7 @@ import torch -from shark_turbine import aot +from iree.turbine import aot from sharktank import kernels from sharktank.types import layout_utils diff --git a/sharktank/tests/kernels/mmt_block_scaled_offset_q4_test.py b/sharktank/tests/kernels/mmt_block_scaled_offset_q4_test.py index d9fc7370a..dca474446 100644 --- a/sharktank/tests/kernels/mmt_block_scaled_offset_q4_test.py +++ b/sharktank/tests/kernels/mmt_block_scaled_offset_q4_test.py @@ -13,7 +13,7 @@ import torch -from shark_turbine import aot +from iree.turbine import aot from sharktank import kernels from sharktank.types import layout_utils diff --git a/sharktank/tests/kernels/mmt_block_scaled_q8_test.py b/sharktank/tests/kernels/mmt_block_scaled_q8_test.py index f3fdf2ed9..08aa8d179 100644 --- a/sharktank/tests/kernels/mmt_block_scaled_q8_test.py +++ b/sharktank/tests/kernels/mmt_block_scaled_q8_test.py @@ -13,7 +13,7 @@ import torch -from shark_turbine import aot +from iree.turbine import aot from sharktank import kernels diff --git a/sharktank/tests/kernels/mmt_super_block_scaled_offset_q4_test.py b/sharktank/tests/kernels/mmt_super_block_scaled_offset_q4_test.py index 41c04106d..01272553a 100644 --- a/sharktank/tests/kernels/mmt_super_block_scaled_offset_q4_test.py +++ b/sharktank/tests/kernels/mmt_super_block_scaled_offset_q4_test.py @@ -13,7 +13,7 @@ import torch -from shark_turbine import aot +from iree.turbine import aot from sharktank import kernels from sharktank.types import layout_utils diff --git a/sharktank/tests/kernels/mmtfp_test.py b/sharktank/tests/kernels/mmtfp_test.py index 281498f90..e2c36e4ac 100644 --- a/sharktank/tests/kernels/mmtfp_test.py +++ b/sharktank/tests/kernels/mmtfp_test.py @@ -13,7 +13,7 @@ import torch -from shark_turbine import aot +from iree.turbine import aot from sharktank import kernels diff --git a/sharktank/tests/kernels/pooling_nchw_sum_test.py b/sharktank/tests/kernels/pooling_nchw_sum_test.py index 5c4e8ac0a..205391d96 100644 --- a/sharktank/tests/kernels/pooling_nchw_sum_test.py +++ b/sharktank/tests/kernels/pooling_nchw_sum_test.py @@ -13,7 +13,7 @@ import torch -from shark_turbine import aot +from iree.turbine import aot from sharktank import kernels from sharktank.ops.qconv_impls import _pad_last_2d diff --git a/sharktank/tests/layers/sharded_conv2d_with_iree_test.py b/sharktank/tests/layers/sharded_conv2d_with_iree_test.py index 9d8b81c62..2a6ecace2 100644 --- a/sharktank/tests/layers/sharded_conv2d_with_iree_test.py +++ b/sharktank/tests/layers/sharded_conv2d_with_iree_test.py @@ -9,7 +9,7 @@ from pathlib import Path import tempfile import torch -from shark_turbine import aot +from iree.turbine import aot from sharktank.models.punet.layers import Conv2DLayer from sharktank import ops from sharktank.types import ( diff --git a/sharktank/tests/models/llama/moe_block_test.py b/sharktank/tests/models/llama/moe_block_test.py index 53706f1bd..edf1d9d97 100644 --- a/sharktank/tests/models/llama/moe_block_test.py +++ b/sharktank/tests/models/llama/moe_block_test.py @@ -8,7 +8,7 @@ from typing import List import torch -from shark_turbine.aot import * +from iree.turbine.aot import * from sharktank.models.llama.testing import make_moe_block_theta, make_rand_torch from sharktank.layers.mixture_of_experts_block import PreGatherMoeBlock from sharktank import ops diff --git a/sharktank/tests/models/llama/sharded_llama_test.py b/sharktank/tests/models/llama/sharded_llama_test.py index 4638df312..bdace4972 100644 --- a/sharktank/tests/models/llama/sharded_llama_test.py +++ b/sharktank/tests/models/llama/sharded_llama_test.py @@ -24,7 +24,7 @@ import tempfile import torch from copy import deepcopy -from shark_turbine.aot import FxProgramsBuilder, export +from iree.turbine.aot import FxProgramsBuilder, export import iree.runtime from pathlib import Path diff --git a/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py b/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py index 82638a495..86bb41c71 100644 --- a/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py +++ b/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py @@ -10,7 +10,7 @@ import torch -from shark_turbine import aot +from iree.turbine import aot from sharktank.models.punet.testing import make_resnet_block_2d_theta from sharktank.models.punet.layers import ResnetBlock2D from sharktank.models.punet.sharding import ResnetBlock2DSplitOutputChannelsSharding diff --git a/sharktank/tests/types/dataset_test.py b/sharktank/tests/types/dataset_test.py index 353bacbb0..4494eab2f 100644 --- a/sharktank/tests/types/dataset_test.py +++ b/sharktank/tests/types/dataset_test.py @@ -11,7 +11,7 @@ import torch -from shark_turbine.aot import ExternalTensorTrait +from iree.turbine.aot import ExternalTensorTrait from sharktank.types import * diff --git a/turbine-requirements.txt b/turbine-requirements.txt index 70078ffa4..0d0dc7619 100644 --- a/turbine-requirements.txt +++ b/turbine-requirements.txt @@ -1 +1 @@ --e "git+https://github.com/iree-org/iree-turbine.git#egg=shark-turbine" +-e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" From baf567c7a653d3de17bb7239a2c97fa58d1d5eb8 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Mon, 7 Oct 2024 17:24:12 -0400 Subject: [PATCH 02/15] [tuner] Fixes for python 3.12 (#258) Account for python changes. --- tuner/tuner/candidate_gen.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index 9bddb5e89..8faf70f85 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -159,10 +159,13 @@ class Configuration: waves_per_eu: int -class MlirRegex(str, Enum): +class MlirRegex(Enum): ssa_value = r"%[a-zA-Z0-9-_]+" tensor_type = r"tensor<(([0-9]+x)+((f|i)[0-9]+))>" + def __str__(self) -> str: + return self.value + @staticmethod def dps_ins_two_args() -> str: return rf"ins\({MlirRegex.ssa_value}, {MlirRegex.ssa_value} : (?P{MlirRegex.tensor_type}), (?P{MlirRegex.tensor_type})\)" @@ -259,7 +262,7 @@ def apply_configuration( def parse_tensor_type(tensor_type: str) -> ShapedType: - shape_match = re.search(MlirRegex.tensor_type, tensor_type) + shape_match = re.search(str(MlirRegex.tensor_type), tensor_type) assert shape_match shape_str = shape_match.group(1) From 4e2f35128eb634c05c6d0876b6672985caca9739 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Mon, 7 Oct 2024 18:09:28 -0400 Subject: [PATCH 03/15] Add dumping of reproducer for toy sharded llama test (#257) Add pytest CLI option for overriding the temporary test directory. If the option is provided dump MLIR, VMFB, function call arguments and results. Move numerical checks at the end of the test. Add some basic functionality to run multi-device IREE functions with tracing. --- sharktank/conftest.py | 55 ++- sharktank/sharktank/types/tensors.py | 8 + sharktank/sharktank/utils/iree.py | 189 +++++++++++ .../tests/models/llama/sharded_llama_test.py | 312 ++++++++---------- 4 files changed, 372 insertions(+), 192 deletions(-) create mode 100644 sharktank/sharktank/utils/iree.py diff --git a/sharktank/conftest.py b/sharktank/conftest.py index 459272d5d..1387b0611 100644 --- a/sharktank/conftest.py +++ b/sharktank/conftest.py @@ -6,7 +6,8 @@ from pathlib import Path import pytest -from typing import Optional +from pytest import FixtureRequest +from typing import Optional, Any # Tests under each top-level directory will get a mark. @@ -47,6 +48,15 @@ def pytest_addoption(parser): default=None, help="Exported model parameters. If not specified a temporary file will be used.", ) + parser.addoption( + "--prefix", + type=str, + default=None, + help=( + "Path prefix for test artifacts. " + "Other arguments may override this for specific values." + ), + ) parser.addoption( "--caching", action="store_true", @@ -55,21 +65,40 @@ def pytest_addoption(parser): ) -@pytest.fixture(scope="session") -def mlir_path(pytestconfig: pytest.Config) -> Optional[Path]: - return pytestconfig.getoption("mlir") +def set_fixture_from_cli_option( + request: FixtureRequest, + cli_option_name: str, + class_attribute_name: Optional[str] = None, +) -> Optional[Any]: + res = request.config.getoption(cli_option_name) + if request.cls is None: + return res + else: + if class_attribute_name is None: + class_attribute_name = cli_option_name + setattr(request.cls, class_attribute_name, res) + + +@pytest.fixture(scope="class") +def mlir_path(request: FixtureRequest) -> Optional[Path]: + return set_fixture_from_cli_option(request, "mlir", "mlir_path") + + +@pytest.fixture(scope="class") +def module_path(request: FixtureRequest) -> Optional[Path]: + return set_fixture_from_cli_option(request, "module", "module_path") -@pytest.fixture(scope="session") -def module_path(pytestconfig: pytest.Config) -> Optional[Path]: - return pytestconfig.getoption("module") +@pytest.fixture(scope="class") +def parameters_path(request: FixtureRequest) -> Optional[Path]: + return set_fixture_from_cli_option(request, "parameters", "parameters_path") -@pytest.fixture(scope="session") -def parameters_path(pytestconfig: pytest.Config) -> Optional[Path]: - return pytestconfig.getoption("parameters") +@pytest.fixture(scope="class") +def path_prefix(request: FixtureRequest) -> Optional[str]: + return set_fixture_from_cli_option(request, "prefix", "path_prefix") -@pytest.fixture(scope="session") -def caching(pytestconfig: pytest.Config) -> Optional[Path]: - return pytestconfig.getoption("caching") +@pytest.fixture(scope="class") +def caching(request: FixtureRequest) -> Optional[bool]: + return set_fixture_from_cli_option(request, "caching") diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py index 200800d44..324cc4331 100644 --- a/sharktank/sharktank/types/tensors.py +++ b/sharktank/sharktank/types/tensors.py @@ -26,6 +26,7 @@ import torch from torch import Tensor from torch.utils._pytree import register_pytree_node, SequenceKey +import torch.utils._pytree from ..utils.math import ceildiv from iree.turbine.aot import ( ExternalTensorTrait, @@ -48,6 +49,7 @@ "ReplicatedTensor", "ShardedTensor", "SplitPrimitiveTensor", + "torch_tree_flatten", "unbox_tensor", "UnreducedTensor", ] @@ -1360,3 +1362,9 @@ def flatten_with_keys_replicated_tensor(t: ReplicatedTensor): unflatten_fn=unflatten_replicated_tensor, flatten_with_keys_fn=flatten_with_keys_replicated_tensor, ) + + +def torch_tree_flatten(tree: tree_utils.Tree): + """Flatten a tree of tensors the same way they will be flattened during torch.export.export + if they are arguments or results of a function signature.""" + return torch.utils._pytree.tree_flatten(tree=tree) diff --git a/sharktank/sharktank/utils/iree.py b/sharktank/sharktank/utils/iree.py new file mode 100644 index 000000000..7c666ff62 --- /dev/null +++ b/sharktank/sharktank/utils/iree.py @@ -0,0 +1,189 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import iree.runtime +from typing import List, Tuple, Optional, Union +from pathlib import Path +import torch +import numpy as np +import collections.abc +from collections import OrderedDict +from ..types.tensors import ( + AnyTensor, + InferenceTensor, + ShardedTensor, + DefaultPrimitiveTensor, + unbox_tensor, + torch_tree_flatten, +) +from .tree import Tree + + +def get_iree_devices(driver: str, device_count: int) -> List[iree.runtime.HalDevice]: + hal_driver = iree.runtime.get_driver(driver) + available_devices = hal_driver.query_available_devices() + if driver in ["local-task", "local-sync"]: + # Use the same actual device for all devices. + return [ + hal_driver.create_device(available_devices[0]) for _ in range(device_count) + ] + else: + return [ + hal_driver.create_device(available_devices[i]) for i in range(device_count) + ] + + +def load_iree_module( + module_path: str, + devices: List[iree.runtime.HalDevice], + parameters_path: Optional[str] = None, +) -> Tuple[iree.runtime.VmModule, iree.runtime.VmContext, iree.runtime.VmInstance]: + """The VmContext and VmInstance need to outlive the VmModule and any device + buffers.""" + vm_instance = iree.runtime.VmInstance() + hal_module = iree.runtime.create_hal_module(instance=vm_instance, devices=devices) + modules = [hal_module] + if parameters_path is not None: + params_path = Path(parameters_path) + parameter_index = iree.runtime.ParameterIndex() + if len(devices) > 1: + # TODO: make IREE able to load the parameters from the top parameter file + # without having to specify the parameter file for each shard separately. + for i in range(len(devices)): + parameter_index.load( + file_path=str( + Path(params_path).with_suffix(f".rank{i}{params_path.suffix}") + ) + ) + else: + parameter_index.load(file_path=str(params_path)) + parameter_provider = parameter_index.create_provider(scope="model") + parameters_module = iree.runtime.create_io_parameters_module( + vm_instance, parameter_provider + ) + modules.append(parameters_module) + vm_module = iree.runtime.VmModule.mmap(vm_instance, str(module_path)) + modules.append(vm_module) + vm_context = iree.runtime.VmContext(instance=vm_instance, modules=modules) + return vm_module, vm_context, vm_instance + + +def run_iree_module_function( + module: iree.runtime.VmModule, + vm_context: iree.runtime.VmContext, + args: List[iree.runtime.DeviceArray], + driver: str, + function_name: str = "main", + trace_path_prefix: Optional[str] = None, +) -> List[iree.runtime.DeviceArray]: + """Run IREE module function with optional tracing of arguments/results.""" + vm_function = module.lookup_function(function_name) + invoker = iree.runtime.FunctionInvoker( + vm_context=vm_context, + # TODO: rework iree.runtime.FunctionInvoker interface for multiple devices. + # This works, but does not look right. + device=iree.runtime.get_device(driver, cache=False), + vm_function=vm_function, + ) + if trace_path_prefix is not None: + for i, arg in enumerate(args): + np.save(f"{trace_path_prefix}{function_name}_arg{i}.npy", arg.to_host()) + results = invoker(*args) + if isinstance(results, iree.runtime.DeviceArray): + results = (results,) + + if trace_path_prefix is not None: + for i, arg in enumerate(args): + np.save( + f"{trace_path_prefix}{function_name}_arg_post_call{i}.npy", + arg.to_host(), + ) + for i, arg in enumerate(results): + np.save(f"{trace_path_prefix}{function_name}_result{i}.npy", arg.to_host()) + return results + + +def prepare_iree_module_function_args( + args: List[Union[AnyTensor, List[AnyTensor]]], devices: List[iree.runtime.HalDevice] +) -> List[iree.runtime.DeviceArray]: + """Flatten composite tensors into their parts and place them on devices. + Sharded tensors become a list of their shards while placing them onto their + corresponding device. + All unsharded tensors go on device 0. + """ + res = [] + for arg in args: + if isinstance(arg, ShardedTensor): + assert len(devices) == len(arg.shards) + res.extend( + [ + prepare_iree_module_function_args([shard], [device])[0] + for shard, device in zip(arg.shards, devices) + ] + ) + elif isinstance(arg, (DefaultPrimitiveTensor, torch.Tensor)): + res.append( + iree.runtime.asdevicearray( + devices[0], unbox_tensor(arg).to("cpu").numpy() + ) + ) + else: + assert isinstance(arg, collections.abc.Sequence) + res.extend(prepare_iree_module_function_args(arg, devices)) + return res + + +def flatten_for_iree_signature(tree: Tree) -> List[torch.Tensor]: + """Flatten a tree of arguments or results for an IREE call. + E.g. sharded tensors gets flattened into their shards.""" + + return torch_tree_flatten(tree)[0] + + +def call_torch_module_function( + module: torch.nn.Module, + function_name: str, + kwargs: OrderedDict, + trace_path_prefix: Optional[str] = None, +): + """Call a torch module function with optional tracing. + For tracing the arguments/results are flattened to match IREE's signature.""" + assert isinstance( + kwargs, OrderedDict + ), "Make sure when flattening the order is preserved" + if trace_path_prefix is not None: + flat_args = flatten_for_iree_signature(kwargs) + for i, arg in enumerate(flat_args): + np.save( + f"{trace_path_prefix}{function_name}_arg{i}.npy", + arg.to("cpu").numpy(), + ) + res = getattr(module, function_name)(**kwargs) + if trace_path_prefix is not None: + flat_args = flatten_for_iree_signature(kwargs) + for i, arg in enumerate(flat_args): + np.save( + f"{trace_path_prefix}{function_name}_arg{i}.npy", + arg.to("cpu").numpy(), + ) + results = ( + (res,) + if isinstance( + res, + ( + torch.Tensor, + InferenceTensor, + ), + ) + else res + ) + flat_results = flatten_for_iree_signature(results) + for i, result in enumerate(flat_results): + np.save( + f"{trace_path_prefix}{function_name}_result{i}.npy", + result.to("cpu").numpy(), + ) + return res diff --git a/sharktank/tests/models/llama/sharded_llama_test.py b/sharktank/tests/models/llama/sharded_llama_test.py index bdace4972..4d34dc704 100644 --- a/sharktank/tests/models/llama/sharded_llama_test.py +++ b/sharktank/tests/models/llama/sharded_llama_test.py @@ -5,115 +5,40 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import unittest -from typing import Any, List, Tuple, Union, OrderedDict -import collections.abc +import pytest +from typing import Any, List, Tuple, OrderedDict from sharktank.models.llama.llama import LlamaModelConfig, PagedLlamaModelV1 import sharktank.ops as ops from sharktank.types import ( unbox_tensor, - ShardedTensor, - DefaultPrimitiveTensor, Dataset, - AnyTensor, ) from sharktank.models.llama.testing import make_random_llama_theta from sharktank.models.llama.sharding import shard_theta from sharktank.layers.configs import LlamaHParams from sharktank.utils.math import round_up_to_multiple_of from sharktank.utils import iterables_equal +from sharktank.utils.iree import ( + get_iree_devices, + load_iree_module, + run_iree_module_function, + prepare_iree_module_function_args, + call_torch_module_function, +) import tempfile import torch from copy import deepcopy from iree.turbine.aot import FxProgramsBuilder, export import iree.runtime -from pathlib import Path - - -def get_iree_devices(driver: str, device_count: int) -> List[iree.runtime.HalDevice]: - hal_driver = iree.runtime.get_driver(driver) - available_devices = hal_driver.query_available_devices() - # Use the same actual device for all devices. - return [hal_driver.create_device(available_devices[0]) for _ in range(device_count)] - - -def load_iree_module( - module_path: str, - parameters_path: str, - devices: List[iree.runtime.HalDevice], -) -> Tuple[iree.runtime.VmModule, iree.runtime.VmContext, iree.runtime.VmInstance]: - params_path = Path(parameters_path) - # TODO: make IREE able to load the parameters from the top parameter file - # without having to specify the parameter file for each shard separately. - parameter_index = iree.runtime.ParameterIndex() - for i in range(len(devices)): - parameter_index.load( - file_path=str( - Path(params_path).with_suffix(f".rank{i}{params_path.suffix}") - ) - ) - parameter_provider = parameter_index.create_provider(scope="model") - vm_instance = iree.runtime.VmInstance() - parameters_module = iree.runtime.create_io_parameters_module( - vm_instance, parameter_provider - ) - vm_module = iree.runtime.VmModule.mmap(vm_instance, str(module_path)) - hal_module = iree.runtime.create_hal_module(instance=vm_instance, devices=devices) - vm_context = iree.runtime.VmContext( - instance=vm_instance, modules=(hal_module, parameters_module, vm_module) - ) - return vm_module, vm_context, vm_instance - - -def run_iree_module_function( - module: iree.runtime.VmModule, - vm_context: iree.runtime.VmContext, - function_name: str, - args: List[iree.runtime.DeviceArray], - driver: str, -) -> List[iree.runtime.DeviceArray]: - vm_function = module.lookup_function(function_name) - invoker = iree.runtime.FunctionInvoker( - vm_context=vm_context, - # TODO: rework iree.runtime.FunctionInvoker interface for multiple devices. - # This works, but does not look right. - device=iree.runtime.get_device(driver, cache=False), - vm_function=vm_function, - ) - res = invoker(*args) - if isinstance(res, iree.runtime.DeviceArray): - res = (res,) - return res - - -def prepare_iree_module_function_args( - args: List[Union[AnyTensor, List[AnyTensor]]], devices: List[iree.runtime.HalDevice] -) -> List[iree.runtime.DeviceArray]: - res = [] - for arg in args: - if isinstance(arg, ShardedTensor): - assert len(devices) == len(arg.shards) - res.extend( - [ - prepare_iree_module_function_args([shard], [device])[0] - for shard, device in zip(arg.shards, devices) - ] - ) - elif isinstance(arg, (DefaultPrimitiveTensor, torch.Tensor)): - res.append( - iree.runtime.asdevicearray( - devices[0], unbox_tensor(arg).to("cpu").numpy() - ) - ) - else: - assert isinstance(arg, collections.abc.Sequence) - res.extend(prepare_iree_module_function_args(arg, devices)) - return res +import numpy as np +import os def iree_to_torch(*tensors: iree.runtime.DeviceArray) -> List[torch.Tensor]: return [torch.tensor(tensor.to_host()) for tensor in tensors] +@pytest.mark.usefixtures("caching", "path_prefix") class ShardedLlamaTest(unittest.TestCase): def setUp(self): torch.random.manual_seed(123456) @@ -304,25 +229,44 @@ def testExportAndRunToySizedModelWithIree(self): """Test exporting to MLIR and compiling with IREE the sharded Llama model. Test numerical accuracy of the IREE module against PyTorch.""" - with tempfile.TemporaryDirectory() as temp_dir: - sharded_theta = shard_theta(self.theta, self.sharded_config) - sharded_theta.rename_tensors_to_paths() - sharded_dataset = Dataset({}, sharded_theta) - sharded_parameters_path = f"{temp_dir}/parameters.irpa" - sharded_dataset.save(sharded_parameters_path) - sharded_dataset = Dataset.load(sharded_parameters_path, mmap=False) - iree_driver = "local-task" - - model = PagedLlamaModelV1(self.theta, self.config) - sharded_model = PagedLlamaModelV1( - sharded_dataset.root_theta, self.sharded_config + if self.path_prefix is not None: + self.runTestExportAndRunToySizedModelWithIree( + path_prefix=self.path_prefix, dump_enabled=True ) - sharded_fxb = FxProgramsBuilder(sharded_model) + else: + with tempfile.TemporaryDirectory() as temp_dir: + self.runTestExportAndRunToySizedModelWithIree( + path_prefix=f"{temp_dir}/", dump_enabled=False + ) + + def runTestExportAndRunToySizedModelWithIree( + self, path_prefix: str, dump_enabled: bool + ): + sharded_theta = shard_theta(self.theta, self.sharded_config) + sharded_theta.rename_tensors_to_paths() + sharded_dataset = Dataset({}, sharded_theta) + sharded_parameters_path = f"{path_prefix}parameters.irpa" + sharded_dataset.save(sharded_parameters_path) + sharded_dataset = Dataset.load(sharded_parameters_path, mmap=False) + iree_driver = "local-task" + + model = PagedLlamaModelV1(self.theta, self.config) + sharded_model = PagedLlamaModelV1( + sharded_dataset.root_theta, self.sharded_config + ) + ( + _, + sharded_prefill_args, + ) = self.make_equal_unsharded_and_sharded_prefill_args(model, sharded_model) + ( + _, + sharded_decode_args, + ) = self.make_equal_unsharded_and_sharded_decode_args(model, sharded_model) - ( - _, - sharded_prefill_args, - ) = self.make_equal_unsharded_and_sharded_prefill_args(model, sharded_model) + iree_module_path = f"{path_prefix}program.vmfb" + if not self.caching or not os.path.exists(iree_module_path): + # Export and compile the IREE module. + sharded_fxb = FxProgramsBuilder(sharded_model) @sharded_fxb.export_program( name="prefill", args=tuple(), kwargs=sharded_prefill_args @@ -330,10 +274,6 @@ def testExportAndRunToySizedModelWithIree(self): def _(model, *args, **kwargs) -> torch.Tensor: return model.prefill(*args, **kwargs) - ( - _, - sharded_decode_args, - ) = self.make_equal_unsharded_and_sharded_decode_args(model, sharded_model) # TODO: remove strict=False when # https://github.com/pytorch/pytorch/issues/136757 # is resolved. @@ -346,91 +286,105 @@ def _(model, *args, **kwargs) -> torch.Tensor: def _(model, *args, **kwargs) -> torch.Tensor: return model.decode(*args, **kwargs) - # Compile the IREE module. output = export(sharded_fxb) - output.save_mlir(f"{temp_dir}/program.mlir") + if dump_enabled: + output.save_mlir(f"{path_prefix}program.mlir") output.session.set_flags( *[ f"--iree-hal-target-device=llvm-cpu[{i}]" for i in range(self.sharded_config.tensor_parallelism_size) ] ) - iree_module_path = f"{temp_dir}/program.vmfb" output.compile( save_to=iree_module_path, target_backends=None, ) - iree_devices = get_iree_devices( - driver=iree_driver, - device_count=self.sharded_config.tensor_parallelism_size, - ) - iree_module, vm_context, vm_instance = load_iree_module( - module_path=iree_module_path, - devices=iree_devices, - parameters_path=sharded_parameters_path, - ) + iree_devices = get_iree_devices( + driver=iree_driver, + device_count=self.sharded_config.tensor_parallelism_size, + ) + iree_module, vm_context, vm_instance = load_iree_module( + module_path=iree_module_path, + devices=iree_devices, + parameters_path=sharded_parameters_path, + ) - # Check IREE's prefill step is close to torch. - prefill_iree_args = prepare_iree_module_function_args( - args=deepcopy(sharded_prefill_args).values(), devices=iree_devices - ) - prefill_iree_result = run_iree_module_function( - args=prefill_iree_args, - function_name="prefill", - module=iree_module, - vm_context=vm_context, - driver=iree_driver, - ) - prefill_iree_result = iree_to_torch(*prefill_iree_result) - assert len(prefill_iree_result) == 1 - expected_prefill_result = sharded_model.prefill(**sharded_prefill_args) - # TODO: Although, not entirely wrong, investigate why this accuracy is that - # low for fp32 (atol=0.0011, rtol=0.013). - torch.testing.assert_close( - prefill_iree_result[0], - expected_prefill_result, - ) - prefill_iree_cache_state_shards = prefill_iree_args[ - -self.config.tensor_parallelism_size - 1 : - ] - prefill_iree_cache_state_shards = iree_to_torch( - *prefill_iree_cache_state_shards - ) - for actual_cache_state_shard, expected_cache_state_shard in zip( - prefill_iree_cache_state_shards, - sharded_prefill_args["cache_state"][0].shards, - ): - # TODO: debug inaccuracy. - torch.testing.assert_close( - actual_cache_state_shard, unbox_tensor(expected_cache_state_shard) - ) + # Run prefill step. + prefill_iree_args = prepare_iree_module_function_args( + args=deepcopy(sharded_prefill_args).values(), devices=iree_devices + ) + for i, arg in enumerate(prefill_iree_args): + np.save(f"{path_prefix}prefill_arg{i}.npy", arg.to_host()) + prefill_iree_result = run_iree_module_function( + args=prefill_iree_args, + function_name="prefill", + module=iree_module, + vm_context=vm_context, + driver=iree_driver, + trace_path_prefix=path_prefix if dump_enabled else None, + ) + prefill_iree_result = iree_to_torch(*prefill_iree_result) + assert len(prefill_iree_result) == 1 + expected_prefill_result = call_torch_module_function( + module=sharded_model, + function_name="prefill", + kwargs=sharded_prefill_args, + trace_path_prefix=f"{path_prefix}expected_" if dump_enabled else None, + ) + prefill_iree_cache_state_shards = prefill_iree_args[ + -self.config.tensor_parallelism_size - 1 : + ] + prefill_iree_cache_state_shards = iree_to_torch( + *prefill_iree_cache_state_shards + ) - # Check IREE's decode step is close to torch. - decode_iree_args = prepare_iree_module_function_args( - args=deepcopy(sharded_decode_args).values(), devices=iree_devices - ) - decode_iree_result = run_iree_module_function( - args=decode_iree_args, - function_name="decode", - module=iree_module, - vm_context=vm_context, + # Run decode step. + decode_iree_args = prepare_iree_module_function_args( + args=deepcopy(sharded_decode_args).values(), devices=iree_devices + ) + decode_iree_result = run_iree_module_function( + args=decode_iree_args, + function_name="decode", + module=iree_module, + vm_context=vm_context, + driver=iree_driver, + trace_path_prefix=path_prefix if dump_enabled else None, + ) + decode_iree_result = iree_to_torch(*decode_iree_result) + expected_decode_result = call_torch_module_function( + module=sharded_model, + function_name="decode", + kwargs=sharded_decode_args, + trace_path_prefix=f"{path_prefix}expected_" if dump_enabled else None, + ) + decode_iree_cache_state_shards = decode_iree_args[ + -self.config.tensor_parallelism_size - 1 : + ] + decode_iree_cache_state_shards = iree_to_torch(*decode_iree_cache_state_shards) + + # Check IREE's numerical correctness against PyTorch. + # TODO: Although, not entirely wrong, investigate why this accuracy is that + # low for fp32 (atol=0.0011, rtol=0.013). + torch.testing.assert_close( + prefill_iree_result[0], + expected_prefill_result, + ) + for actual_cache_state_shard, expected_cache_state_shard in zip( + prefill_iree_cache_state_shards, + sharded_prefill_args["cache_state"][0].shards, + ): + # TODO: debug inaccuracy. + torch.testing.assert_close( + actual_cache_state_shard, unbox_tensor(expected_cache_state_shard) ) - decode_iree_result = iree_to_torch(*decode_iree_result) - expected_decode_result = sharded_model.decode(**sharded_decode_args) + # TODO: debug inaccuracy. + torch.testing.assert_close(decode_iree_result[0], expected_decode_result) + for actual_cache_state_shard, expected_cache_state_shard in zip( + decode_iree_cache_state_shards, + sharded_decode_args["cache_state"][0].shards, + ): # TODO: debug inaccuracy. - torch.testing.assert_close(decode_iree_result[0], expected_decode_result) - decode_iree_cache_state_shards = decode_iree_args[ - -self.config.tensor_parallelism_size - 1 : - ] - decode_iree_cache_state_shards = iree_to_torch( - *decode_iree_cache_state_shards + torch.testing.assert_close( + actual_cache_state_shard, unbox_tensor(expected_cache_state_shard) ) - for actual_cache_state_shard, expected_cache_state_shard in zip( - decode_iree_cache_state_shards, - sharded_decode_args["cache_state"][0].shards, - ): - # TODO: debug inaccuracy. - torch.testing.assert_close( - actual_cache_state_shard, unbox_tensor(expected_cache_state_shard) - ) From b55065a2b747ed4a4b755a9f34d04e6bcabdfa4d Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Wed, 9 Oct 2024 12:20:55 -0400 Subject: [PATCH 04/15] Enable check for sharded Conv2D test (#263) The fix https://github.com/iree-org/iree-turbine/pull/205 solves the issue with this test. Xfail the Unet Resnet block test with maybe low accuracy. --- .../layers/sharded_conv2d_with_iree_test.py | 14 +++++------ .../sharded_resnet_block_with_iree_test.py | 24 ++++++++++++------- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/sharktank/tests/layers/sharded_conv2d_with_iree_test.py b/sharktank/tests/layers/sharded_conv2d_with_iree_test.py index 2a6ecace2..9b29e5761 100644 --- a/sharktank/tests/layers/sharded_conv2d_with_iree_test.py +++ b/sharktank/tests/layers/sharded_conv2d_with_iree_test.py @@ -173,14 +173,12 @@ def run_test_sharded_conv2d_with_iree( ) assert len(actual_result.shards) == len(expected_result.shards) assert actual_result.shard_dim == expected_result.shard_dim - # TODO: reenable this check once numerical issues are resolved. - # See https://github.com/iree-org/iree/issues/18283 - # for actual_shard, expected_shard in zip( - # actual_result.shards, expected_result.shards - # ): - # torch.testing.assert_close( - # unbox_tensor(actual_shard), unbox_tensor(expected_shard) - # ) + for actual_shard, expected_shard in zip( + actual_result.shards, expected_result.shards + ): + torch.testing.assert_close( + unbox_tensor(actual_shard), unbox_tensor(expected_shard) + ) def test_sharded_conv2d_with_iree( diff --git a/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py b/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py index 86bb41c71..581584369 100644 --- a/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py +++ b/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py @@ -19,6 +19,7 @@ import iree.runtime from typing import List, Optional import os +import pytest vm_context: iree.runtime.VmContext = None @@ -207,19 +208,26 @@ def run_test_sharded_resnet_block_with_iree( parameters_path=parameters_path, ) assert len(actual_result.shards) == len(expected_result.shards) - # TODO: reenable this check once numerical issues are resolved. - # See https://github.com/iree-org/iree/issues/18283 - # for actual_shard, expected_shard in zip( - # actual_result.shards, expected_result.shards - # ): - # torch.testing.assert_close( - # unbox_tensor(actual_shard), unbox_tensor(expected_shard) - # ) + # TODO: reenable this test once numerical issues are resolved. + # The absolute accuracy is > 0.00042. Is this good enough? + # Maybe add a test with fp64, where if the accuracy is high would give us more + # confidence that fp32 is also OK. + for actual_shard, expected_shard in zip( + actual_result.shards, expected_result.shards + ): + torch.testing.assert_close( + unbox_tensor(actual_shard), unbox_tensor(expected_shard) + ) global vm_context del vm_context +@pytest.mark.xfail( + reason="Maybe numerical issues with low accuracy.", + strict=True, + raises=AssertionError, +) def test_sharded_resnet_block_with_iree( mlir_path: Optional[Path], module_path: Optional[Path], From a0d5d10542a698db9ed96d1d73da663f87b84ded Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Fri, 11 Oct 2024 13:40:07 +0200 Subject: [PATCH 05/15] Cleanup workflow files (#270) * Removes extra path from command line * Adds quotes to make call compatible with Windows CI * Removes no longer required deps --- .github/workflows/ci_linux_x64-libshortfin.yml | 11 ++++------- .github/workflows/ci_linux_x64_nogil-libshortfin.yml | 5 ++--- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/.github/workflows/ci_linux_x64-libshortfin.yml b/.github/workflows/ci_linux_x64-libshortfin.yml index bdb4620be..d7450cbe7 100644 --- a/.github/workflows/ci_linux_x64-libshortfin.yml +++ b/.github/workflows/ci_linux_x64-libshortfin.yml @@ -41,7 +41,6 @@ jobs: run: | sudo apt update sudo apt install clang lld cmake ninja-build - sudo apt install libspdlog-dev libxtensor-dev - name: Checkout repository uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 @@ -89,9 +88,8 @@ jobs: -DCMAKE_CXX_COMPILER=clang++-18 \ -DCMAKE_LINKER_TYPE=LLD \ -DSHORTFIN_BUNDLE_DEPS=ON \ - -DSHORTFIN_IREE_SOURCE_DIR=${{ env.IREE_REPO_DIR }} \ - -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON \ - .. + -DSHORTFIN_IREE_SOURCE_DIR="${{ env.IREE_REPO_DIR }}" \ + -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON cmake --build build --target all pip install -v -e build/ @@ -113,10 +111,9 @@ jobs: -DCMAKE_C_COMPILER=clang-18 \ -DCMAKE_CXX_COMPILER=clang++-18 \ -DCMAKE_LINKER_TYPE=LLD \ - -DSHORTFIN_IREE_SOURCE_DIR=${{ env.IREE_REPO_DIR }} \ + -DSHORTFIN_IREE_SOURCE_DIR="${{ env.IREE_REPO_DIR }}" \ -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON \ -DSHORTFIN_HAVE_AMDGPU=OFF \ -DSHORTFIN_BUILD_STATIC=ON \ - -DSHORTFIN_BUILD_DYNAMIC=ON \ - .. + -DSHORTFIN_BUILD_DYNAMIC=ON cmake --build build-host-only --target all diff --git a/.github/workflows/ci_linux_x64_nogil-libshortfin.yml b/.github/workflows/ci_linux_x64_nogil-libshortfin.yml index 08f5e62da..12efdadda 100644 --- a/.github/workflows/ci_linux_x64_nogil-libshortfin.yml +++ b/.github/workflows/ci_linux_x64_nogil-libshortfin.yml @@ -86,9 +86,8 @@ jobs: -DCMAKE_CXX_COMPILER=clang++-18 \ -DCMAKE_LINKER_TYPE=LLD \ -DSHORTFIN_BUNDLE_DEPS=ON \ - -DSHORTFIN_IREE_SOURCE_DIR=${{ env.IREE_REPO_DIR }} \ - -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON \ - .. + -DSHORTFIN_IREE_SOURCE_DIR="${{ env.IREE_REPO_DIR }}" \ + -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON cmake --build build --target all pip install -v -e build/ From e4bcf99bb2387a2280da3dc26993a73878bad8b8 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Fri, 11 Oct 2024 13:01:53 -0400 Subject: [PATCH 06/15] [tuner] Fix mfma constructor arguments (#266) --- tuner/tuner/candidate_gen.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index 8faf70f85..16f0cf724 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -452,11 +452,11 @@ def generate_solutions(problem_size: ProblemSize, num_subgrups: int): lookup(subgroup_size), [lookup(wg_x), lookup(wg_y), lookup(wg_z)], MfmaIntrinsic( - problem_size.lhs_type.element_type, + problem_size.res_type.element_type, lookup(intrinsic_mn), lookup(intrinsic_mn), lookup(intrinsic_k), - problem_size.res_type.element_type, + problem_size.lhs_type.element_type, ), [lookup(m), lookup(n), lookup(k)], lookup(sg_m_cnt), From 468fb29ee6999473845d59af2ad17cbd834a2409 Mon Sep 17 00:00:00 2001 From: Mihaescu Vlad <52869843+mihaescuvlad@users.noreply.github.com> Date: Fri, 11 Oct 2024 20:26:30 +0300 Subject: [PATCH 07/15] [tuner] Use JSON for benchmark output (#256) ### Notes - Adds `extract_benchmark_from_run_result` method to help in fetching the "benchmarks" data - Updates `IREEBenchmarkResult` model to reflect that the result is no longer stored as string but as a list of benchmarks - Updates the parsing of dispatches and models to read from Json ### Testing - Updates tests to verify that `get_mean_time` functions as expected - Updates tests to verify that Json data is properly parsed and processed --- tuner/examples/dispatch/dispatch_tuner.py | 1 + tuner/examples/punet/punet_autotune.py | 6 +- tuner/tuner/libtuner.py | 85 +++++++++--- tuner/tuner/libtuner_test.py | 161 ++++++++++++++++------ 4 files changed, 192 insertions(+), 61 deletions(-) diff --git a/tuner/examples/dispatch/dispatch_tuner.py b/tuner/examples/dispatch/dispatch_tuner.py index 98086fbbb..3c2d77f64 100644 --- a/tuner/examples/dispatch/dispatch_tuner.py +++ b/tuner/examples/dispatch/dispatch_tuner.py @@ -58,6 +58,7 @@ def get_dispatch_benchmark_command( f"--module={compiled_vmfb_path.resolve()}", "--batch_size=1000", "--benchmark_repetitions=3", + "--benchmark_format=json", ] return command diff --git a/tuner/examples/punet/punet_autotune.py b/tuner/examples/punet/punet_autotune.py index b78989991..3503c86df 100644 --- a/tuner/examples/punet/punet_autotune.py +++ b/tuner/examples/punet/punet_autotune.py @@ -58,8 +58,7 @@ def get_dispatch_benchmark_command( "--hip_allow_inline_execution=true", "--batch_size=1000", "--benchmark_repetitions=3", - f"--benchmark_out=dispatch_{candidate_tracker.candidate_id}_bm.json", - "--benchmark_out_format=json", + "--benchmark_format=json", ] return command @@ -110,8 +109,7 @@ def get_model_benchmark_command( "--input=2x6xf16", "--input=1xf16", "--benchmark_repetitions=5", - f"--benchmark_out=model_{candidate_tracker.candidate_id}_bm.json", - "--benchmark_out_format=json", + "--benchmark_format=json", ] return command diff --git a/tuner/tuner/libtuner.py b/tuner/tuner/libtuner.py index 30ce732bd..91c7b417a 100644 --- a/tuner/tuner/libtuner.py +++ b/tuner/tuner/libtuner.py @@ -36,6 +36,7 @@ from typing import Type, Optional, Callable, Iterable, Any import pickle import random +import json from abc import ABC, abstractmethod import iree.runtime as ireert from . import candidate_gen @@ -236,20 +237,48 @@ class ParsedDisptachBenchmarkResult: class IREEBenchmarkResult: # Default format follows output of iree-benchmark-module candidate_id: int - result_str: str - def get_mean_time(self) -> Optional[float]: - if not self.result_str: - return None - pattern = r"process_time/real_time_mean\s+([\d.]+)\s\w{2}" - match = re.search(pattern, self.result_str) - if not match: - return None - try: - return float(match.group(1)) - except ValueError: + # A list of dictionaries, each representing a benchmark result + # Each dictionary contains fields like: aggregate_name: string, real_time: float, cpu_time: float, time_unit: str, repetitions: int, etc. + result_json: list[dict[str, Any]] + + def get_mean_time_us(self) -> Optional[float]: + """Compute the mean time (in microseconds) for all of the benchmarks""" + if not self.result_json: return None + mean_benchmark = self.find_mean_benchmark(self.result_json) + + if mean_benchmark: + real_time = mean_benchmark.get("real_time") + time_unit = mean_benchmark.get("time_unit") + + if real_time is not None: + return self.unit_to_microseconds(real_time, time_unit) + + return None + + @staticmethod + def find_mean_benchmark(result_json: list[dict[str, Any]]) -> Optional[dict]: + for benchmark in result_json: + if benchmark.get("aggregate_name") == "mean": + return benchmark + + return None + + @staticmethod + def unit_to_microseconds(real_time: float, time_unit: str) -> float: + unit_conversions = { + "s": 1e6, + "ms": 1e3, + "us": 1, + "ns": 1e-3, + } + + assert time_unit in unit_conversions, f"Unsupported time unit: {time_unit}" + + return real_time * unit_conversions[time_unit] + def generate_display_DBR(candidate_id: int, mean_time: float) -> str: """Generate dispatch_benchmark_result string for displaying""" @@ -619,6 +648,26 @@ def multiprocess_progress_wrapper( return results +def extract_benchmark_from_run_result( + run_result: RunResult, +) -> Optional[list[dict[str, Any]]]: + """Extract the benchmark from the result JSON""" + if run_result.process_res and run_result.process_res.stdout: + try: + result_json = json.loads(run_result.process_res.stdout) + + return result_json.get("benchmarks", None) + except json.JSONDecodeError as e: + handle_error( + condition=True, + msg=f"Failed to parse JSON from stdout: {e}", + error_type=ValueError, + exit_program=True, + ) + + return None + + def numerical_sort_key(path: Path) -> tuple[int | float, str]: """ Define a sort key function that splits the filename into a numeric and a string part. @@ -896,9 +945,9 @@ def parse_dispatch_benchmark_results( incomplete_list.append(candidate_id) continue - res_str = process_res.stdout - res = IREEBenchmarkResult(candidate_id, res_str) - benchmark_time = res.get_mean_time() + res_json = extract_benchmark_from_run_result(benchmark_result.run_result) + res = IREEBenchmarkResult(candidate_id, res_json) + benchmark_time = res.get_mean_time_us() assert benchmark_time is not None candidate_trackers[candidate_id].first_benchmark_time = benchmark_time candidate_trackers[ @@ -1185,9 +1234,9 @@ def parse_model_benchmark_results( baseline_time = None continue - result_str = process_res.stdout - res = IREEBenchmarkResult(candidate_id, result_str) - benchmark_time = res.get_mean_time() + result_json = extract_benchmark_from_run_result(task_result.run_result) + res = IREEBenchmarkResult(candidate_id, result_json) + benchmark_time = res.get_mean_time_us() assert benchmark_time is not None # Record baseline benchmarking result and skip rest processes @@ -1328,7 +1377,7 @@ def benchmark_models( ) -def summerize_top_candidates( +def summarize_top_candidates( path_config: PathConfig, candidate_trackers: list[CandidateTracker] ): dump_list = [] diff --git a/tuner/tuner/libtuner_test.py b/tuner/tuner/libtuner_test.py index 3cbaa5ed0..36bda3bd5 100644 --- a/tuner/tuner/libtuner_test.py +++ b/tuner/tuner/libtuner_test.py @@ -6,11 +6,12 @@ import argparse import pytest +import json from unittest.mock import call, patch, MagicMock from . import libtuner """ -Usage: python -m pytest test_libtuner.py +Usage: python -m pytest libtuner_test.py """ @@ -57,34 +58,77 @@ def test_collision_handler(): def test_IREEBenchmarkResult_get(): - # Time is int - normal_str = r""" - ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ - Benchmark Time CPU Iterations UserCounters... - ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ - BM_main$async_dispatch_311_rocm_hsaco_fb_main$async_dispatch_311_matmul_like_2x1024x1280x5120_i8xi8xi32/process_time/real_time 271 us 275 us 3000 items_per_second=3.65611k/s - BM_main$async_dispatch_311_rocm_hsaco_fb_main$async_dispatch_311_matmul_like_2x1024x1280x5120_i8xi8xi32/process_time/real_time 274 us 275 us 3000 items_per_second=3.65481k/s - BM_main$async_dispatch_311_rocm_hsaco_fb_main$async_dispatch_311_matmul_like_2x1024x1280x5120_i8xi8xi32/process_time/real_time 273 us 275 us 3000 items_per_second=3.65671k/s - BM_main$async_dispatch_311_rocm_hsaco_fb_main$async_dispatch_311_matmul_like_2x1024x1280x5120_i8xi8xi32/process_time/real_time_mean 274 us 275 us 3 items_per_second=3.65587k/s - BM_main$async_dispatch_311_rocm_hsaco_fb_main$async_dispatch_311_matmul_like_2x1024x1280x5120_i8xi8xi32/process_time/real_time_mean 275 us 275 us 3 items_per_second=3.65611k/s - BM_main$async_dispatch_311_rocm_hsaco_fb_main$async_dispatch_311_matmul_like_2x1024x1280x5120_i8xi8xi32/process_time/real_time_stddev 0.073 us 0.179 us 3 items_per_second=0.971769/s - BM_main$async_dispatch_311_rocm_hsaco_fb_main$async_dispatch_311_matmul_like_2x1024x1280x5120_i8xi8xi32/process_time/real_time_cv 0.03 % 0.07 % 3 items_per_second=0.03% - """ - res = libtuner.IREEBenchmarkResult(candidate_id=1, result_str=normal_str) - assert res.get_mean_time() == float(274) - - # Time is float + # Time is int in us + int_json = [{"aggregate_name": "mean", "real_time": 1, "time_unit": "us"}] + + res = libtuner.IREEBenchmarkResult(candidate_id=1, result_json=int_json) + assert res.get_mean_time_us() == float(1) + + # Time is float in us + float_json = [{"aggregate_name": "mean", "real_time": 123.45, "time_unit": "us"}] + + res = libtuner.IREEBenchmarkResult(candidate_id=2, result_json=float_json) + assert res.get_mean_time_us() == 123.45 + + # Time is in seconds + seconds_json = [{"aggregate_name": "mean", "real_time": 1.0, "time_unit": "s"}] + + res = libtuner.IREEBenchmarkResult(candidate_id=3, result_json=seconds_json) + assert res.get_mean_time_us() == 1.0 * 1e6 + + # Time is in miliseconds + miliseconds_json = [{"aggregate_name": "mean", "real_time": 1.0, "time_unit": "ms"}] + + res = libtuner.IREEBenchmarkResult(candidate_id=4, result_json=miliseconds_json) + assert res.get_mean_time_us() == 1.0 * 1e3 + + # Time is in nanoseconds + nanoseconds_json = [{"aggregate_name": "mean", "real_time": 1.0, "time_unit": "ns"}] + + res = libtuner.IREEBenchmarkResult(candidate_id=5, result_json=nanoseconds_json) + assert res.get_mean_time_us() == 1.0 * 1e-3 + + small_number_json = [ + { + "aggregate_name": "mean", + "real_time": 3.4591828516259519e-02, + "time_unit": "ms", + } + ] + + res = libtuner.IREEBenchmarkResult(candidate_id=6, result_json=small_number_json) + assert res.get_mean_time_us() == 34.591828516259519 + + # Invalid json: missing real_time + invalid_real_time_json = [{"aggregate_name": "mean", "real_time": None}] + res = libtuner.IREEBenchmarkResult( - candidate_id=2, - result_str="process_time/real_time_mean 123.45 us, process_time/real_time_mean 246.78 us", + candidate_id=7, result_json=invalid_real_time_json ) - assert res.get_mean_time() == 123.45 + assert res.get_mean_time_us() == None - # Invalid str - res = libtuner.IREEBenchmarkResult(candidate_id=3, result_str="hello world") - assert res.get_mean_time() == None - res = libtuner.IREEBenchmarkResult(candidate_id=4, result_str="") - assert res.get_mean_time() == None + # Invalid json: empty dictionary + res = libtuner.IREEBenchmarkResult(candidate_id=8, result_json={}) + assert res.get_mean_time_us() is None + + # Invalid json: invalid time unit + invalid_time_unit_json = [ + {"aggregate_name": "mean", "real_time": 1.0, "time_unit": "invalid_unit"} + ] + + with pytest.raises(AssertionError, match="Unsupported time unit: invalid_unit"): + res = libtuner.IREEBenchmarkResult( + candidate_id=9, result_json=invalid_time_unit_json + ) + res.get_mean_time_us() + + # Invalid json: missing aggregate_name + invalid_aggregate_name_json = [{"real_time": 1.0, "time_unit": "us"}] + + res = libtuner.IREEBenchmarkResult( + candidate_id=10, result_json=invalid_aggregate_name_json + ) + assert res.get_mean_time_us() is None def test_generate_display_BR(): @@ -110,15 +154,37 @@ def test_parse_dispatch_benchmark_results(): object.__setattr__(path_config, "specs_dir", spec_dir) mock_result_1 = MagicMock() - mock_result_1.run_result.process_res.stdout = "process_time/real_time_mean 100.0 us" + mock_json_1 = { + "benchmarks": [ + {"aggregate_name": "mean", "real_time": 100.0, "time_unit": "us"} + ] + } + mock_result_1.run_result.process_res.stdout = json.dumps(mock_json_1) mock_result_1.candidate_id = 1 mock_result_2 = MagicMock() - mock_result_2.run_result.process_res.stdout = "process_time/real_time_mean 200.0 us" + mock_json_2 = { + "benchmarks": [ + {"aggregate_name": "mean", "real_time": 200.0, "time_unit": "us"} + ] + } + mock_result_2.run_result.process_res.stdout = json.dumps(mock_json_2) mock_result_2.candidate_id = 2 mock_result_3 = MagicMock() - mock_result_3.run_result.process_res = None # Incomplete result + mock_json_3 = { + "benchmarks": [ + { + "aggregate_name": "mean", + "real_time": 3.4591828516259519e-02, + "time_unit": "ms", + } + ] + } + mock_result_3.run_result.process_res.stdout = json.dumps(mock_json_3) mock_result_3.candidate_id = 3 - benchmark_results = [mock_result_1, mock_result_2, mock_result_3] + mock_result_4 = MagicMock() + mock_result_4.run_result.process_res = None # Incomplete result + mock_result_4.candidate_id = 4 + benchmark_results = [mock_result_1, mock_result_2, mock_result_3, mock_result_4] candidate_trackers = [] for i in range(4): @@ -139,11 +205,18 @@ def test_parse_dispatch_benchmark_results(): candidate_mlir=libtuner.Path("/mock/mlir/path/2.mlir"), candidate_spec_mlir=libtuner.Path("/mock/base/dir/specs/2_spec.mlir"), ), + libtuner.ParsedDisptachBenchmarkResult( + candidate_id=3, + benchmark_time_in_seconds=34.591828516259519, + candidate_mlir=libtuner.Path("/mock/mlir/path/3.mlir"), + candidate_spec_mlir=libtuner.Path("/mock/base/dir/specs/3_spec.mlir"), + ), ] expected_dump_list = [ "1\tMean Time: 100.0\n", "2\tMean Time: 200.0\n", - "Candidate 3 not completed", + "3\tMean Time: 34.6\n", + "Candidate 4 not completed", ] parsed_results, dump_list = libtuner.parse_dispatch_benchmark_results( @@ -160,6 +233,10 @@ def test_parse_dispatch_benchmark_results(): assert candidate_trackers[2].spec_path == libtuner.Path( "/mock/base/dir/specs/2_spec.mlir" ) + assert candidate_trackers[3].first_benchmark_time == 34.591828516259519 + assert candidate_trackers[3].spec_path == libtuner.Path( + "/mock/base/dir/specs/3_spec.mlir" + ) def test_parse_model_benchmark_results(): @@ -180,22 +257,26 @@ def test_parse_model_benchmark_results(): # Setup mock data for task results result1 = MagicMock() - result1.run_result.process_res.stdout = "1.23" + result_json_1 = {"benchmarks": [{"real_time": 1.23}]} + result1.run_result.process_res.stdout = json.dumps(result_json_1) result1.candidate_id = 1 result1.device_id = "device1" result2 = MagicMock() - result2.run_result.process_res.stdout = "4.56" + result_json_2 = {"benchmarks": [{"real_time": 4.56}]} + result2.run_result.process_res.stdout = json.dumps(result_json_2) result2.candidate_id = 2 result2.device_id = "device2" result3 = MagicMock() - result3.run_result.process_res.stdout = "0.98" + result_json_3 = {"benchmarks": [{"real_time": 0.98}]} + result3.run_result.process_res.stdout = json.dumps(result_json_3) result3.candidate_id = 0 result3.device_id = "device1" result4 = MagicMock() - result4.run_result.process_res.stdout = "4.13" + result_json_4 = {"benchmarks": [{"real_time": 4.13}]} + result4.run_result.process_res.stdout = json.dumps(result_json_4) result4.candidate_id = 0 result4.device_id = "device2" @@ -206,7 +287,8 @@ def test_parse_model_benchmark_results(): result5.device_id = "device3" result6 = MagicMock() - result6.run_result.process_res.stdout = "3.38" + result_json_6 = {"benchmarks": [{"real_time": 3.38}]} + result6.run_result.process_res.stdout = json.dumps(result_json_6) result6.candidate_id = 3 result6.device_id = "device3" @@ -214,12 +296,13 @@ def test_parse_model_benchmark_results(): baseline_results = [result3, result4, result5] # Skip real benchmark extraction, directly use given values from above - def mock_get_mean_time(self): - return float(self.result_str) if self.result_str else None + def mock_get_mean_time_us(self): + return float(self.result_json[0]["real_time"]) if self.result_json else None # Mock IREEBenchmarkResult to return wanted benchmark times with patch( - f"{libtuner.__name__}.IREEBenchmarkResult.get_mean_time", new=mock_get_mean_time + f"{libtuner.__name__}.IREEBenchmarkResult.get_mean_time_us", + new=mock_get_mean_time_us, ): # Mock handle_error to avoid actual logging during tests with patch(f"{libtuner.__name__}.handle_error") as mock_handle_error: From 459de98f4a74159d459b68e742516334e2013748 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Fri, 11 Oct 2024 19:59:40 +0200 Subject: [PATCH 08/15] Pin (and update) actions (#268) Updates the checkout action as this uses a deprecated Node.js version and the old version will therefore be forced to run on node20. Further pins actions as suggested byt OpenSSF Scorecard, see https://github.com/ossf/scorecard/blob/main/docs/checks.md#pinned-dependencies. --- .github/workflows/ci-tuner.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci-tuner.yml b/.github/workflows/ci-tuner.yml index 5de7d4182..1944caa6a 100644 --- a/.github/workflows/ci-tuner.yml +++ b/.github/workflows/ci-tuner.yml @@ -20,10 +20,10 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4.1.7 + uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # v5.1.1 with: python-version: '3.10.12' From 3015ec7c24052fbd1826cfb4190f7f7d7d8d7c90 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 11 Oct 2024 14:41:54 -0700 Subject: [PATCH 09/15] [sharktank] Add test for sharded rotary table (#274) We should be able to validate the sharded rotary table via comparison with the unsharded version. This runs the sharded and unsharded implementations, asserting near identical results. --- .../layers/sharded_rotary_embedding_test.py | 56 +++++++++++++++++++ 1 file changed, 56 insertions(+) create mode 100644 sharktank/tests/layers/sharded_rotary_embedding_test.py diff --git a/sharktank/tests/layers/sharded_rotary_embedding_test.py b/sharktank/tests/layers/sharded_rotary_embedding_test.py new file mode 100644 index 000000000..963b9b432 --- /dev/null +++ b/sharktank/tests/layers/sharded_rotary_embedding_test.py @@ -0,0 +1,56 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + + +import torch + +from sharktank.layers import RotaryEmbeddingLayer +from sharktank import ops +from sharktank.types import ( + ShardedTensor, + SplitPrimitiveTensor, + unbox_tensor, +) + +import unittest +from typing import List, Optional +import os + + +def test_sharded_rotary_table(): + bs = 4 + rope_dims = 16 + heads = 8 + max_seqlen = 128 + rope_freq_base = None + + # First we setup and get the default rotary embedding layer + xq = torch.rand((bs, max_seqlen, heads, rope_dims), dtype=torch.float) + xk = torch.rand((bs, max_seqlen, heads, rope_dims), dtype=torch.float) + default_layer = RotaryEmbeddingLayer( + rope_dimension_count=rope_dims, + max_seqlen=max_seqlen, + rope_freq_base=rope_freq_base, + ) + oq, ok = default_layer(xq=xq, xk=xk, start_index=0) + + # Then we can shard the same inputs and layer + xq = SplitPrimitiveTensor(ts=xq, shard_dim=2, shard_count=4) + xk = SplitPrimitiveTensor(ts=xk, shard_dim=2, shard_count=4) + shard_layer = RotaryEmbeddingLayer( + rope_dimension_count=rope_dims, + max_seqlen=max_seqlen, + rope_freq_base=rope_freq_base, + tensor_parallelism_size=4, + ) + sq, sk = shard_layer(xq=xq, xk=xk, start_index=0) + + # Gathering and unboxing should yield the same results + sq = ops.unshard(sq) + sk = ops.unshard(sk) + + torch.testing.assert_close(sq, oq) + torch.testing.assert_close(sk, ok) From 355761ba28a489bb33028f8f1403ed5e16302afa Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Mon, 14 Oct 2024 15:20:52 -0400 Subject: [PATCH 10/15] Add sharded paged attention test (#276) Verify that the sharded Llama paged attention block behaves in PyTorch as the unsharded variant. The fp32 accuracy seems low and this test is xfailed. The fp64 accuracy is fine. --- .../sharktank/layers/rotary_embedding.py | 2 +- sharktank/sharktank/models/llama/sharding.py | 2 +- .../sharded_paged_llama_attention_block.py | 163 ++++++++++++++++++ 3 files changed, 165 insertions(+), 2 deletions(-) create mode 100644 sharktank/tests/layers/sharded_paged_llama_attention_block.py diff --git a/sharktank/sharktank/layers/rotary_embedding.py b/sharktank/sharktank/layers/rotary_embedding.py index 39e8490d3..834ea349f 100644 --- a/sharktank/sharktank/layers/rotary_embedding.py +++ b/sharktank/sharktank/layers/rotary_embedding.py @@ -21,7 +21,7 @@ def __init__( *, rope_dimension_count: int, max_seqlen: int, - rope_freq_base: float, + rope_freq_base: Optional[float], device: Optional[torch.device] = None, use_hf: bool = False, static_tables: bool = True, diff --git a/sharktank/sharktank/models/llama/sharding.py b/sharktank/sharktank/models/llama/sharding.py index 1a98419e6..3715a3923 100644 --- a/sharktank/sharktank/models/llama/sharding.py +++ b/sharktank/sharktank/models/llama/sharding.py @@ -4,7 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -"""Specifications describing how blocks/layers of llama are sharded.""" +"""Specifications describing how the Llama model is sharded.""" from ...types.sharding import * from ...types import Theta diff --git a/sharktank/tests/layers/sharded_paged_llama_attention_block.py b/sharktank/tests/layers/sharded_paged_llama_attention_block.py new file mode 100644 index 000000000..c94fd44ab --- /dev/null +++ b/sharktank/tests/layers/sharded_paged_llama_attention_block.py @@ -0,0 +1,163 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import unittest +from sharktank.layers import ( + PagedLlamaAttentionBlock, + PagedKVCache, + RotaryEmbeddingLayer, +) +from sharktank.layers.testing import make_llama_attention_block_theta, make_rand_torch +from sharktank.models.llama.sharding import PagedLlamaAttentionBlockSharding +from sharktank.types import SplitPrimitiveTensor, unbox_tensor +import torch +from sharktank import ops +from copy import deepcopy +import pytest + + +class ShardedPagedLlamaAttentionBlockTest(unittest.TestCase): + """Verify that the sharded Llama paged attention block behaves in PyTorch as the + unsharded variant.""" + + def setUp(self): + torch.manual_seed(12345) + self.transformer_block_count = 13 + self.block_index = 1 + self.shard_count = 3 + self.head_count_kv = 2 * self.shard_count + self.attention_head_count = 5 * self.head_count_kv + self.attention_head_dim = 11 * 2 + self.rms_epsilon = 0.01 + self.block_seq_stride = 17 + self.cache_partition_count = 2 + self.page_count = 23 + self.embedding_length = self.attention_head_count * self.attention_head_dim + self.rope_dimension_count = self.attention_head_dim + self.block_seqlen = 7 + self.max_seqlen = self.block_seq_stride * self.block_seqlen + self.rope_freq_base = None + self.batch_size = 3 + self.start_index = 0 + + def testSmallSizedLayerFp64(self): + self.runTestSmallSizedLayer(dtype=torch.float64) + + @pytest.mark.xfail( + reason="The accuracy seems low (atol=0.0018, rtol=0.5065)", + strict=True, + raises=AssertionError, + ) + def testSmallSizedLayerFp32(self): + self.runTestSmallSizedLayer(dtype=torch.float32) + + def runTestSmallSizedLayer(self, dtype: torch.dtype): + torch.set_default_dtype(dtype) + + def make_paged_kv_cache(shard_count: int) -> PagedKVCache: + return PagedKVCache( + transformer_block_count=self.transformer_block_count, + attn_head_count=self.head_count_kv, + attn_head_dim=self.attention_head_dim, + cache_partition_count=self.cache_partition_count, + block_seq_stride=self.block_seq_stride, + dtype=dtype, + shard_count=shard_count, + ) + + cache = make_paged_kv_cache(shard_count=1) + sharded_cache = make_paged_kv_cache(shard_count=self.shard_count) + + def make_unsharded_and_sharded_equal_cache_states() -> tuple[ + list[torch.Tensor], list[SplitPrimitiveTensor] + ]: + cache_state = cache.allocate(self.page_count) + cache_state[0] = make_rand_torch(cache_state[0].shape, dtype=dtype) + sharded_cache_state = sharded_cache.shard_state(deepcopy(cache_state)) + return cache_state, sharded_cache_state + + ( + cache_state, + sharded_cache_state, + ) = make_unsharded_and_sharded_equal_cache_states() + + input_tensor = make_rand_torch( + ( + self.batch_size, + self.max_seqlen, + self.attention_head_count * self.attention_head_dim, + ), + dtype=dtype, + ) + seq_block_ids = torch.arange(self.batch_size * self.block_seqlen).view( + self.batch_size, -1 + ) + embedding_module = RotaryEmbeddingLayer( + rope_dimension_count=self.rope_dimension_count, + max_seqlen=self.max_seqlen, + rope_freq_base=self.rope_freq_base, + ) + + theta = make_llama_attention_block_theta( + head_count=self.attention_head_count, + head_count_kv=self.head_count_kv, + head_dim=self.attention_head_dim, + embedding_length=self.embedding_length, + ) + attention_block = PagedLlamaAttentionBlock( + theta=theta, + block_index=self.block_index, + cache=cache, + head_count=self.attention_head_count, + head_dim=self.attention_head_dim, + head_count_kv=self.head_count_kv, + rms_epsilon=self.rms_epsilon, + ) + expected_result = attention_block( + input_tensor, + embedding=embedding_module, + seq_block_ids=seq_block_ids, + start_index=self.start_index, + cache_state=cache_state, + ) + + sharded_input_tensor = ops.replicate(input_tensor, count=self.shard_count) + sharded_seq_block_ids = ops.replicate(seq_block_ids, count=self.shard_count) + sharded_embedding_module = RotaryEmbeddingLayer( + rope_dimension_count=self.rope_dimension_count, + max_seqlen=self.max_seqlen, + rope_freq_base=self.rope_freq_base, + tensor_parallelism_size=self.shard_count, + ) + + theta_sharding = PagedLlamaAttentionBlockSharding(shard_count=self.shard_count) + sharded_theta = ops.reshard(theta, theta_sharding) + sharded_attention_block = PagedLlamaAttentionBlock( + theta=sharded_theta, + block_index=self.block_index, + cache=sharded_cache, + head_count=self.attention_head_count, + head_dim=self.attention_head_dim, + head_count_kv=self.head_count_kv, + rms_epsilon=self.rms_epsilon, + ) + sharded_result = sharded_attention_block( + sharded_input_tensor, + embedding=sharded_embedding_module, + seq_block_ids=sharded_seq_block_ids, + start_index=self.start_index, + cache_state=sharded_cache_state, + ) + + actual_result = unbox_tensor(ops.unshard(sharded_result)) + actual_cache_state = unbox_tensor( + ops.unshard( + sharded_cache.unflatten_page_table(sharded_cache_state) + ).flatten(start_dim=1) + ) + + torch.testing.assert_close(actual_result, expected_result) + torch.testing.assert_close(actual_cache_state, cache_state[0]) From acd77e317c02803ed63e6c3a7a4dc0b033702215 Mon Sep 17 00:00:00 2001 From: Kyle Herndon Date: Mon, 14 Oct 2024 12:54:37 -0700 Subject: [PATCH 11/15] Add special einsum cases that lower to batch matmul (#262) --- .../kernels/mmt_block_scaled_offset_q4.py | 40 +++++++---- .../mmt_block_scaled_offset_q4_unsigned.mlir | 30 ++++++-- sharktank/sharktank/ops/default_impls.py | 71 +++++++++++++++++-- sharktank/sharktank/types/tensors.py | 2 +- 4 files changed, 120 insertions(+), 23 deletions(-) diff --git a/sharktank/sharktank/kernels/mmt_block_scaled_offset_q4.py b/sharktank/sharktank/kernels/mmt_block_scaled_offset_q4.py index 2ed171115..0c8a61f32 100644 --- a/sharktank/sharktank/kernels/mmt_block_scaled_offset_q4.py +++ b/sharktank/sharktank/kernels/mmt_block_scaled_offset_q4.py @@ -37,28 +37,33 @@ def select(self, ksel: KernelSelection): m_desc = ksel.arg_tensor(3) # Shape [N, K // BLOCK_SIZE, 1] # a arg - *batch_dims, a_m, a_k = a_desc.t.shape + *a_batch_dims, a_m, a_k = a_desc.t.shape torch._check( a_desc.t.dtype.is_floating_point, lambda: f"mmt_block_scaled_offset_q4_unsigned arg 'a': Expected floating point (got {a_desc.t.dtype})", ) torch._check( - len(batch_dims) == 1, + len(a_batch_dims) == 1, lambda: f"mmt_block_scaled_offset_q4_unsigned arg 'a': Expected 3d tensor (got {a_desc.t.shape})", ) # qs arg - qs_n, qs_group0, qs_bs_div_2, *rest = qs_desc.t.shape + *qs_batch_dims, qs_n, qs_group0, qs_bs_div_2 = qs_desc.t.shape torch._check( - len(rest) == 0 and (qs_group0 * qs_bs_div_2 * 2) == a_k, + ( + len(qs_batch_dims) == 0 + or len(qs_batch_dims) == 1 + and qs_batch_dims == a_batch_dims + ) + and (qs_group0 * qs_bs_div_2 * 2) == a_k, lambda: f"mmt_block_scaled_offset_q4_unsigned arg 'qs': Incorrect shape (got {qs_desc.t.shape})", ) block_size = qs_bs_div_2 * 2 # d arg - d_n, d_group0, d_one, *rest = d_desc.t.shape + *d_batch_dims, d_n, d_group0, d_one = d_desc.t.shape torch._check( - len(rest) == 0 + d_batch_dims == qs_batch_dims and (d_group0 * block_size) == a_k and d_one == 1 and d_n == qs_n, @@ -66,9 +71,9 @@ def select(self, ksel: KernelSelection): ) # m arg - m_n, m_group0, m_one, *rest = m_desc.t.shape + *m_batch_dims, m_n, m_group0, m_one = m_desc.t.shape torch._check( - len(rest) == 0 + m_batch_dims == qs_batch_dims and (m_group0 * block_size) == a_k and m_one == 1 and m_n == qs_n, @@ -81,12 +86,17 @@ def select(self, ksel: KernelSelection): # Specialize on K, N, BS a_desc.specialize_dims(-1) - qs_desc.specialize_all_dims() - d_desc.specialize_all_dims() - m_desc.specialize_all_dims() + if len(qs_batch_dims) == 0: + qs_desc.specialize_all_dims() + d_desc.specialize_all_dims() + m_desc.specialize_all_dims() + else: + qs_desc.specialize_dims(1, 2, 3) + d_desc.specialize_dims(1, 2, 3) + m_desc.specialize_dims(1, 2, 3) # Shape batch..., m, n - c_desc = ksel.return_new_tensor(batch_dims + [a_m, d_n], dtype=a_desc.t.dtype) + c_desc = ksel.return_new_tensor(a_batch_dims + [a_m, d_n], dtype=a_desc.t.dtype) c_desc.specialize_dims(-1) def generate(self, ksel: KernelSelection, kb: KernelBuilder): @@ -99,13 +109,14 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder): rank = a_tensor_type.rank k = a_tensor_type.get_dim_size(rank - 1) - n, group0, bs_i8 = qs_tensor_type.shape + *qs_batch_dims, n, group0, bs_i8 = qs_tensor_type.shape + batched_rhs = len(qs_batch_dims) == 1 bs = bs_i8 * 2 # 2 nibbles per byte. a_type_str = str(a_tensor_type.element_type) scale_type_str = str(d_tensor_type.element_type) template_file = "mmt_block_scaled_offset_q4_unsigned.mlir" - target_function_name = f"sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{n}_{k}_{bs}_{a_type_str}" + target_function_name = f"sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{n}_{k}_{bs}_{a_type_str}_{batched_rhs}" target_function = inline_template_function( kb, @@ -118,5 +129,6 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder): group0=group0, a_type=a_type_str, scale_type=scale_type_str, + batched_rhs=batched_rhs, ) kb.yield_results(*call_function(target_function, *kb.arg_bindings)) diff --git a/sharktank/sharktank/kernels/templates/mmt_block_scaled_offset_q4_unsigned.mlir b/sharktank/sharktank/kernels/templates/mmt_block_scaled_offset_q4_unsigned.mlir index a7f3138cb..afe2928c0 100644 --- a/sharktank/sharktank/kernels/templates/mmt_block_scaled_offset_q4_unsigned.mlir +++ b/sharktank/sharktank/kernels/templates/mmt_block_scaled_offset_q4_unsigned.mlir @@ -12,17 +12,25 @@ !accum_type = {{accum_type}} !a_tensor_type = tensor !aexp_tensor_type = tensor +{% if batched_rhs %} +!qs_raw_tensor_type = tensor +!qs_tensor_type = tensor +!d_tensor_type = tensor +!m_tensor_type = tensor +!b_grouped_tensor_type = tensor +{% else %} !qs_raw_tensor_type = tensor<{{n}}x{{group0}}x{{bs_i8}}xi8> !qs_tensor_type = tensor<{{n}}x{{group0}}x{{bs}}x!lowp_type> !d_tensor_type = tensor<{{n}}x{{group0}}x1x!scale_type> !m_tensor_type = tensor<{{n}}x{{group0}}x1x!scale_type> +!b_grouped_tensor_type = tensor<{{n}}x{{group0}}x{{bs}}x!a_type> +{% endif %} !accum_tensor_type = tensor !c_tensor_type = tensor -!b_grouped_tensor_type = tensor<{{n}}x{{group0}}x{{bs}}x!a_type> module { -util.func private @sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{{n}}_{{k}}_{{bs}}_{{a_type}}( +util.func private @sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{{n}}_{{k}}_{{bs}}_{{a_type}}_{{batched_rhs}}( %a: !a_tensor_type, %d: !d_tensor_type, %qs_raw: !qs_raw_tensor_type, %m: !m_tensor_type) -> !c_tensor_type { %zero = arith.constant 0.0: !accum_type @@ -32,17 +40,31 @@ util.func private @sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{{n}}_{{k}}_ %m_dim = tensor.dim %a, %c1 : !a_tensor_type // Cast qs_raw from i8 to lowp type. +{% if batched_rhs %} + %qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type{ %batch0_dim } -> !qs_tensor_type{ %batch0_dim } + %b_grouped = tensor.empty(%batch0_dim) : !b_grouped_tensor_type +{% else %} %qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type -> !qs_tensor_type + %b_grouped = tensor.empty() : !b_grouped_tensor_type +{% endif %} // Dequantize. - %b_grouped = tensor.empty() : !b_grouped_tensor_type %b_grouped_dequant = linalg.generic { +{% if batched_rhs %} + indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel"] } +{% else %} indexing_maps = [ affine_map<(d0, d1, d2) -> (d0, d1, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"] } +{% endif %} ins(%d, %m, %qs : !d_tensor_type, !m_tensor_type, !qs_tensor_type) outs(%b_grouped : !b_grouped_tensor_type) { ^bb0(%d_element: !scale_type, %m_element: !scale_type, %q_element: !lowp_type, %out: !a_type): @@ -70,7 +92,7 @@ util.func private @sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{{n}}_{{k}}_ indexing_maps = [ // d0 = b, d1 = m, d2 = n, d3 = group0 (r), d4 = block (r) affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, - affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> ({% if batched_rhs %}d0,{% endif %} d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"] } ins(%aexp, %b_grouped_dequant : !aexp_tensor_type, !b_grouped_tensor_type) diff --git a/sharktank/sharktank/ops/default_impls.py b/sharktank/sharktank/ops/default_impls.py index fec30fca6..ef7144bca 100644 --- a/sharktank/sharktank/ops/default_impls.py +++ b/sharktank/sharktank/ops/default_impls.py @@ -69,9 +69,56 @@ def conv2d_default( conv2d.override(Tensor, Tensor, auto_dequant=True)(conv2d_default) # Einsum -@einsum_2args.override(AllOfType(Tensor, PrimitiveTensor)) -def einsum_2args(x, y, einsum_str): - return torch.einsum(einsum_str, unbox_tensor(x), unbox_tensor(y)) +def mk_menk_men(inputs, weights): + # batch dims: m, lhs pdims: none, lhs rdims: k, rhs pdims: en, rhs rdims: k + inputs = inputs.unsqueeze(1) + weights_shape = weights.shape + weights = weights.view( + weights_shape[0], weights_shape[1] * weights_shape[2], weights_shape[3] + ) + result = matmul(inputs, weights, transpose_rhs=True) + result = result.view(weights_shape[0], weights_shape[1], weights_shape[2]) + return result + + +def mek_menk_men(inputs, weights): + # batch dims: me, lhs pdims: none, lhs rdims: k, rhs pdims: n, rhs rdims: k + inputs_shape = inputs.shape + inputs = inputs.view(inputs_shape[0] * inputs_shape[1], 1, inputs_shape[2]) + weights_shape = weights.shape + weights = weights.view( + weights_shape[0] * weights_shape[1], weights_shape[2], weights_shape[3] + ) + result = matmul(inputs, weights, transpose_rhs=True) + result = result.view(weights_shape[0], weights_shape[1], weights_shape[2]) + return result + + +def me_men_men(inputs, weights): + # batch dims: me, lhs pdims: none, lhs rdims: none, rhs pdims: n, rhs rdims: none + inputs_shape = inputs.shape + inputs = inputs.view(inputs_shape[0] * inputs_shape[1], 1, 1) + weights_shape = weights.shape + weights = weights.view(weights_shape[0] * weights_shape[1], weights_shape[2], 1) + result = matmul(inputs, weights, transpose_rhs=True) + result = result.view(weights_shape[0], weights_shape[1], weights_shape[2]) + return result + + +@einsum_2args.override(AllOfType(Tensor, PrimitiveTensor, QuantizedTensor)) +def einsum_2args(input0, input1, einsum_str): + # Special optimized einsum kernels that lower to batch matmul + if einsum_str == "mk,menk->men": + return mk_menk_men(input0, input1) + elif einsum_str == "mek,menk->men": + return mek_menk_men(input0, input1) + elif einsum_str == "me,men->men": + return me_men_men(input0, input1) + # Default non-QuantizedTensor einsum + if not isinstance(input1, QuantizedTensor): + return torch.einsum(einsum_str, unbox_tensor(x), unbox_tensor(y)) + # Fallback to other kernels + return NotImplemented # Elementwise @@ -307,7 +354,7 @@ def matmul_default(lhs, rhs, *, transpose_rhs: bool) -> Tensor: lhs = unbox_tensor(lhs) rhs = unbox_tensor(rhs) if transpose_rhs: - rhs = rhs.T + rhs = rhs.mT return torch.matmul(lhs, rhs.to(lhs.dtype)) @@ -433,3 +480,19 @@ def unsqueeze_default(tensor: Union[Tensor, PrimitiveTensor], dim: int) -> Tenso @view.override(Tensor) def view_default(tensor: Union[Tensor, PrimitiveTensor], shape: List[int]) -> Tensor: return unbox_tensor(tensor).view(*shape) + + +@view.override(QuantizedTensor) +def view_QuantizedTensor(tensor: QuantizedTensor, shape): + unpacked = tensor.unpack() + if not isinstance(unpacked, BlockScaledI4Layout): + return NotImplemented + bs = 16 + shape = list(shape) + new_d = unpacked._d.view(shape[:-1] + [shape[-1] // 32, 1]) + qs_shape = shape[:-1] + [shape[-1] // 32, 16] + new_qs = unpacked._qs.view(qs_shape) + if unpacked.m is not None: + new_m = unpacked.m.view(shape[:-1] + [shape[-1] // 32, 1]) + layout = BlockScaledI4Layout(shape=shape, d=new_d, qs=new_qs, m=new_m) + return PlanarQuantizedTensor(shape=shape, layout=layout) diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py index 324cc4331..70b0fbd01 100644 --- a/sharktank/sharktank/types/tensors.py +++ b/sharktank/sharktank/types/tensors.py @@ -378,7 +378,7 @@ def unsqueeze(self, dim: int) -> "AnyTensor": def view(self, *args: Union[List[List[int]], List[int]]) -> "AnyTensor": from ..ops import view - if all(isinstance(a, int) for a in args): + if all(isinstance(a, int) or isinstance(a, torch.SymInt) for a in args): shape = args else: assert len(args) == 1 From 854bea30ff8cd4bae13fb562a948236c8e6f37a7 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Mon, 14 Oct 2024 13:10:48 -0700 Subject: [PATCH 12/15] Rework RotaryEmbedding for dynamic computation (#255) Some minor changes to the rotary embedding can better support fusion and avoid using a lookup table. Depending on backend one version may provide better overall performance. --- .../sharktank/layers/rotary_embedding.py | 68 ++++++++++++------- 1 file changed, 45 insertions(+), 23 deletions(-) diff --git a/sharktank/sharktank/layers/rotary_embedding.py b/sharktank/sharktank/layers/rotary_embedding.py index 834ea349f..18a95aba3 100644 --- a/sharktank/sharktank/layers/rotary_embedding.py +++ b/sharktank/sharktank/layers/rotary_embedding.py @@ -4,6 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +from collections import namedtuple from typing import Optional, Union import torch @@ -24,7 +25,8 @@ def __init__( rope_freq_base: Optional[float], device: Optional[torch.device] = None, use_hf: bool = False, - static_tables: bool = True, + static_tables: bool = False, + use_table: bool = True, tensor_parallelism_size: int = 1, ): super().__init__() @@ -32,6 +34,8 @@ def __init__( self.rope_dimension_count = rope_dimension_count self.max_seqlen = max_seqlen self.use_hf = use_hf + self.static_tables = static_tables + self.use_table = use_table self.rope_freq_base = rope_freq_base if rope_freq_base is not None else 10000.0 self.tensor_parallelism_size = tensor_parallelism_size @@ -44,10 +48,16 @@ def __init__( @property def rotary_embed_table(self): - if self.static_rotary_embed_table is None: + if self.use_table: + if self.static_tables: + return self.static_rotary_embed_table return self._create_rotary_embed_table() - else: - return self.static_rotary_embed_table + + if self.tensor_parallelism_size == 1: + return None + + nt = namedtuple("replicated_tensor", ["shards"]) + return nt([None] * self.tensor_parallelism_size) def forward( self, @@ -96,7 +106,7 @@ def forward_unsharded( xq: torch.Tensor, xk: torch.Tensor, start_index: int, - rotary_embed_table: torch.Tensor, + rotary_embed_table: Optional[torch.Tensor], ): # xq_, xk_ shape: bs, sl, _, dim # freqs_cis shape: max_sl, dim @@ -142,12 +152,18 @@ def create_ordering_tensor(dim): xq = xq[..., create_interleaved_tensor(xq.shape[-1])] xk = xk[..., create_interleaved_tensor(xq.shape[-1])] - xq_ = torch.view_as_complex(xq.reshape(*xq.shape[:-1], -1, 2)) - xk_ = torch.view_as_complex(xk.reshape(*xk.shape[:-1], -1, 2)) + xq_ = torch.view_as_complex(xq.unflatten(-1, (-1, 2))) + xk_ = torch.view_as_complex(xk.unflatten(-1, (-1, 2))) _, sl, _, dim = xq_.shape # Offset the table based on starting position. - freqs_cis = rotary_embed_table[start_index : start_index + sl, :] + if self.use_table: + freqs_cis = rotary_embed_table[start_index : start_index + sl, :] + else: + freqs_cis = torch.arange(start_index, start_index + sl, device=xq.device) + freqs_cis = self._compute_rotary_embed_table(freqs_cis) + freqs_cis = self._replicate(freqs_cis) + assert freqs_cis.shape[-1] == dim assert ( freqs_cis.shape[0] >= sl @@ -206,7 +222,13 @@ def compute_batch_mask( ) + start_positions.unsqueeze(1) # Broadcast lookup to [b, ...]. self.trace_tensor("rope.positions_seq", positions_seq) - freqs_cis = self.rotary_embed_table[positions_seq] + + if self.use_table: + freqs_cis = self.rotary_embed_table[positions_seq] + else: + shape = positions_seq.shape + freqs_cis = self._compute_rotary_embed_table(positions_seq.flatten()) + freqs_cis = freqs_cis.unflatten(0, shape) # Unsqueeze a unit dim for attention heads. broadcast_freqs_cis = freqs_cis.unsqueeze(2) @@ -225,10 +247,6 @@ def apply_batched_mask( and xq.shard_count == xk.shard_count and xk.shard_dim == xq.shard_dim ) - assert ( - isinstance(self.rotary_embed_table, ReplicatedTensor) - and xq.shard_count == self.rotary_embed_table.shard_count - ) assert ( isinstance(mask, ReplicatedTensor) and mask.shard_count == xq.shard_count @@ -263,24 +281,20 @@ def apply_batched_mask_unsharded( """ # xq_, xk_ shape: bs, sl, _, dim # freqs_cis shape: max_sl, dim - xq_ = torch.view_as_complex(xq.reshape(*xq.shape[:-1], -1, 2)) - xk_ = torch.view_as_complex(xk.reshape(*xk.shape[:-1], -1, 2)) + xq_ = torch.view_as_complex(xq.unflatten(-1, (-1, 2))) + xk_ = torch.view_as_complex(xk.unflatten(-1, (-1, 2))) _, sl, _, dim = xq_.shape xq_out = torch.view_as_real(xq_ * mask).flatten(3) xk_out = torch.view_as_real(xk_ * mask).flatten(3) return xq_out.type_as(xq), xk_out.type_as(xk) - def _create_rotary_embed_table( - self, - ): + def _compute_rotary_embed_table(self, t): dim = self.rope_dimension_count - max_seqlen = self.max_seqlen freqs = 1.0 / ( self.rope_freq_base - ** (torch.arange(0, dim, 2, device=self.device)[: (dim // 2)].float() / dim) + ** (torch.arange(0, dim, 2, device=t.device)[: (dim // 2)].float() / dim) ) - t = torch.arange(max_seqlen, device=freqs.device) freqs = torch.outer(t, freqs).float() freqs_cis = ( @@ -289,8 +303,16 @@ def _create_rotary_embed_table( else torch.polar(torch.ones_like(freqs), freqs) ) + return freqs_cis + + def _create_rotary_embed_table(self): + t = torch.arange(self.max_seqlen, device=self.device) + freqs_cis = self._compute_rotary_embed_table(t) + return self._replicate(freqs_cis) + + def _replicate(self, t): if self.tensor_parallelism_size > 1: # Replicate across all devices, the data is not a lot and the computation is cheap. - freqs_cis = ops.replicate(freqs_cis, self.tensor_parallelism_size) + t = ops.replicate(t, self.tensor_parallelism_size) - return freqs_cis + return t From ffee822cf1aedca1da8375f91746581dbf3fb466 Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Tue, 15 Oct 2024 09:48:06 -0700 Subject: [PATCH 13/15] Refresh metadata in sharktank/setup.py. (#247) * Drop author email for consistency with shortfin (we could also pick a mailing list or more up to date contact email) * Update URL following repository rename More will be needed as we add a dep on `shortfin` and resume regular publishing of this `sharktank` package to https://pypi.org/project/sharktank/ --- sharktank/setup.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sharktank/setup.py b/sharktank/setup.py index 21be90019..ab6e92d33 100644 --- a/sharktank/setup.py +++ b/sharktank/setup.py @@ -78,11 +78,10 @@ def initialize_options(self): name=f"sharktank", version=f"{PACKAGE_VERSION}", author="SHARK Authors", - author_email="stella@nod.ai", description="SHARK layers and inference models for genai", long_description=README, long_description_content_type="text/markdown", - url="https://github.com/nod-ai/sharktank", + url="https://github.com/nod-ai/SHARK-Platform", license="Apache-2.0", classifiers=[ "Development Status :: 3 - Alpha", From f8fd09bd46e07e23978178df11c4268286581644 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Tue, 15 Oct 2024 19:23:17 +0200 Subject: [PATCH 14/15] [libshortfin] Bump nanobind to version 2.0.0 (#278) --- shortfin/python/CMakeLists.txt | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/shortfin/python/CMakeLists.txt b/shortfin/python/CMakeLists.txt index a8ebfeaa0..adf9d7879 100644 --- a/shortfin/python/CMakeLists.txt +++ b/shortfin/python/CMakeLists.txt @@ -11,12 +11,10 @@ # Others. # nanobind -# Pinned to a pre 2.2.0 commit hash which includes free threaded support. -# TODO: Bump to 2.2.0 when available. FetchContent_Declare( nanobind GIT_REPOSITORY https://github.com/wjakob/nanobind.git - GIT_TAG 8ce0dee7f62add575f85c0de386a9c819e4d50af # HEAD > 2.1.0 + GIT_TAG 784efa2a0358a4dc5432c74f5685ee026e20f2b6 # 2.2.0 ) FetchContent_MakeAvailable(nanobind) From 14301820924f0d76ead3250db789502da6bed2d7 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Tue, 15 Oct 2024 19:46:18 +0200 Subject: [PATCH 15/15] Add ci_windows_x64-libshortfin.yml (#269) Adds a Windows CI for shortfin. Building passes (with several warnings) but not all ctests and pytests do. As soon as we have passing tests, the CI configuration should be merged with the Linux CI into a matrix configuration. --- .../workflows/ci_windows_x64-libshortfin.yml | 95 +++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 .github/workflows/ci_windows_x64-libshortfin.yml diff --git a/.github/workflows/ci_windows_x64-libshortfin.yml b/.github/workflows/ci_windows_x64-libshortfin.yml new file mode 100644 index 000000000..c60bd816d --- /dev/null +++ b/.github/workflows/ci_windows_x64-libshortfin.yml @@ -0,0 +1,95 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +name: CI - shortfin - Windows + +on: + workflow_dispatch: + pull_request: + push: + branches: + - main + paths: + - '.github/workflows/ci_windows_x64-libshortfin.yml' + - 'shortfin/**' + +permissions: + contents: read + +concurrency: + # A PR number if a pull request and otherwise the commit hash. This cancels + # queued and in-progress runs for the same PR (presubmit) or commit + # (postsubmit). The workflow name is prepended to avoid conflicts between + # different workflows. + group: ${{ github.workflow }}-${{ github.event.number || github.sha }} + cancel-in-progress: true + +env: + IREE_REPO_DIR: ${{ github.workspace }}/iree + LIBSHORTFIN_DIR: ${{ github.workspace }}/shortfin/ + +jobs: + build-and-test: + name: Build and test + runs-on: windows-2022 + + steps: + - name: Configure MSVC + uses: ilammy/msvc-dev-cmd@0b201ec74fa43914dc39ae48a89fd1d8cb592756 # v1.13.0 + + - name: Checkout repository + uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + with: + submodules: false + + - name: Checkout IREE repo + uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + with: + repository: iree-org/iree + path: ${{ env.IREE_REPO_DIR }} + submodules: false + ref: candidate-20240904.1006 + + - name: Initalize IREE submodules + working-directory: ${{ env.IREE_REPO_DIR }} + run : | + git submodule update --init --depth 1 -- third_party/benchmark + git submodule update --init --depth 1 -- third_party/cpuinfo/ + git submodule update --init --depth 1 -- third_party/flatcc + git submodule update --init --depth 1 -- third_party/googletest + git submodule update --init --depth 1 -- third_party/hip-build-deps/ + + - name: Setup Python + uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # v5.1.1 + with: + python-version: "3.12" + cache: "pip" + - name: Install Python packages + working-directory: ${{ env.LIBSHORTFIN_DIR }} + run: | + pip install -r requirements-tests.txt + pip install -r requirements-iree-compiler.txt + pip freeze + + - name: Build shortfin (full) + working-directory: ${{ env.LIBSHORTFIN_DIR }} + shell: bash + run: | + mkdir build + cmake -GNinja \ + -S. \ + -Bbuild \ + -DSHORTFIN_BUNDLE_DEPS=ON \ + -DSHORTFIN_IREE_SOURCE_DIR="${{ env.IREE_REPO_DIR }}" \ + -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON + cmake --build build --target all + pip install -v -e build/ + + - name: Test shortfin (full) + working-directory: ${{ env.LIBSHORTFIN_DIR }} + run: | + ctest --timeout 30 --output-on-failure --test-dir build + pytest -s