diff --git a/sharktank/sharktank/ops/sharded_impls.py b/sharktank/sharktank/ops/sharded_impls.py
index b1ef57090..a99219022 100644
--- a/sharktank/sharktank/ops/sharded_impls.py
+++ b/sharktank/sharktank/ops/sharded_impls.py
@@ -88,7 +88,7 @@ def conv2d_all_split(
         input.is_replicated or input.shard_dim == 1
     ), "Only sharding of input channel dimension is supported"
     assert (
-        weight.shard_dim == 0 and bias.shard_dim == 0
+        weight.shard_dim == 0 and (bias is None or bias.shard_dim == 0)
     ), "Only sharding of output channel dimension is supported"
 
     # TODO: allow for implementation where we don't all-gather, but gather
@@ -146,7 +146,7 @@ def conv2d_replicated_input_split_weight_and_bias(
     assert input.shard_count == weight.shard_count
     assert bias is None or weight.shard_count == bias.shard_count
     assert (
-        weight.shard_dim == 0 and bias.shard_dim == 0
+        weight.shard_dim == 0 and (bias is None or bias.shard_dim == 0)
     ), "Only sharding of output channel dimension is supported"
     assert groups == 1
 
@@ -189,7 +189,8 @@ def conv2d_split_weight_and_bias(
     accum_dtype,
 ) -> SplitPrimitiveTensor:
     assert accum_dtype is None, "accum_dtype not supported"
-    assert weight.shard_count == bias.shard_count
+    if bias is not None:
+        assert weight.shard_count == bias.shard_count
 
     # Output channels dimension is split.
     if weight.shard_dim == 0 and groups == 1:
diff --git a/sharktank/tests/layers/conftest.py b/sharktank/tests/layers/conftest.py
new file mode 100644
index 000000000..1c88ceb31
--- /dev/null
+++ b/sharktank/tests/layers/conftest.py
@@ -0,0 +1,59 @@
+# Copyright 2024 Advanced Micro Devices, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import pytest
+from pathlib import Path
+from typing import Optional
+
+
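+# Example invocation (paths are illustrative):
+#   pytest sharktank/tests/layers/sharded_conv2d_with_iree_test.py \
+#     --mlir=/tmp/model.mlir --module=/tmp/module.vmfb --parameters=/tmp/params.irpa --caching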
+def pytest_addoption(parser):
+    parser.addoption(
+        "--mlir",
+        type=Path,
+        default=None,
+        help="Path to exported MLIR program. If not specified, a temporary file will be used.",
+    )
+    parser.addoption(
+        "--module",
+        type=Path,
+        default=None,
+        help="Path to exported IREE module. If not specified, a temporary file will be used.",
+    )
+    parser.addoption(
+        "--parameters",
+        type=Path,
+        default=None,
+        help="Exported model parameters. If not specified, a temporary file will be used.",
+    )
+    parser.addoption(
+        "--caching",
+        action="store_true",
+        default=False,
+        help="Load cached results if present instead of recomputing.",
+    )
+
+
+@pytest.fixture(scope="session")
+def mlir_path(pytestconfig: pytest.Config) -> Optional[Path]:
+    return pytestconfig.getoption("mlir")
+
+
+@pytest.fixture(scope="session")
+def module_path(pytestconfig: pytest.Config) -> Optional[Path]:
+    return pytestconfig.getoption("module")
+
+
+@pytest.fixture(scope="session")
+def parameters_path(pytestconfig: pytest.Config) -> Optional[Path]:
+    return pytestconfig.getoption("parameters")
+
+
+@pytest.fixture(scope="session")
+def caching(pytestconfig: pytest.Config) -> bool:
+    return pytestconfig.getoption("caching")
diff --git a/sharktank/tests/layers/sharded_conv2d_with_iree_test.py b/sharktank/tests/layers/sharded_conv2d_with_iree_test.py
new file mode 100644
index 000000000..e74718ea6
--- /dev/null
+++ b/sharktank/tests/layers/sharded_conv2d_with_iree_test.py
@@ -0,0 +1,215 @@
+# Copyright 2024 Advanced Micro Devices, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from pathlib import Path
+import tempfile
+import torch
+from shark_turbine import aot
+from sharktank.models.punet.layers import Conv2DLayer
+from sharktank import ops
+from sharktank.types import (
+    Dataset,
+    DefaultPrimitiveTensor,
+    Theta,
+    ShardedTensor,
+    SplitPrimitiveTensor,
+    unbox_tensor,
+)
+from sharktank.types.sharding import Conv2DSplitOutputChannelSharding
+import iree.runtime
+from typing import List, Optional
+import os
+
+vm_context: iree.runtime.VmContext = None
+
+
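+# One --iree-hal-target-device flag is emitted per shard so the compiled module
+# sees a distinct logical device for each shard, e.g. for shard_count=2:
+#   --iree-hal-target-device=llvm-cpu[0] --iree-hal-target-device=llvm-cpu[1]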
+def get_compiler_args(target_device_kind: str, shard_count: int) -> List[str]:
+    result = [
+        f"--iree-hal-target-device={target_device_kind}[{i}]"
+        for i in range(shard_count)
+    ]
+    return result
+
+
+def compile_iree_module(
+    export_output: aot.ExportOutput, module_path: str, shard_count: int
+):
+    export_output.session.set_flags(
+        *get_compiler_args(target_device_kind="llvm-cpu", shard_count=shard_count)
+    )
+    export_output.compile(save_to=module_path, target_backends=None)
+
+
+# TODO: improve IREE's Python API to be more concise in a multi-device context.
+# This run function should be way shorter.
+def run_iree_module(
+    sharded_input_image: ShardedTensor,
+    module_path: str,
+    parameters_path: str,
+) -> ShardedTensor:
+    shard_count = sharded_input_image.shard_count
+    hal_driver = iree.runtime.get_driver("local-task")
+    vm_instance = iree.runtime.VmInstance()
+    available_devices = hal_driver.query_available_devices()
+    # Use the same physical device for all logical devices.
+    devices = [
+        hal_driver.create_device(available_devices[0]) for _ in range(shard_count)
+    ]
+    hal_module = iree.runtime.create_hal_module(instance=vm_instance, devices=devices)
+    params_path = Path(parameters_path)
+    # TODO: make IREE able to load the parameters from the top parameter file
+    # without having to specify the parameter file for each shard separately.
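+    # E.g. for parameters_path "params.irpa" and shard_count=2, this loads
+    # "params.rank0.irpa" and "params.rank1.irpa".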
+    parameter_index = iree.runtime.ParameterIndex()
+    for i in range(shard_count):
+        parameter_index.load(
+            file_path=str(
+                Path(params_path).with_suffix(f".rank{i}{params_path.suffix}")
+            )
+        )
+    parameter_provider = parameter_index.create_provider(scope="model")
+    parameters_module = iree.runtime.create_io_parameters_module(
+        vm_instance, parameter_provider
+    )
+
+    vm_module = iree.runtime.VmModule.mmap(vm_instance, str(module_path))
+
+    # The context needs to be destroyed after the buffers, although
+    # it is not associated with them at the API level.
+    global vm_context
+    vm_context = iree.runtime.VmContext(
+        instance=vm_instance, modules=(hal_module, parameters_module, vm_module)
+    )
+    module_input_args = [
+        iree.runtime.asdevicearray(
+            devices[i], sharded_input_image.shards[i].as_torch().to("cpu").numpy()
+        )
+        for i in range(shard_count)
+    ]
+
+    vm_function = vm_module.lookup_function("main")
+    invoker = iree.runtime.FunctionInvoker(
+        vm_context=vm_context,
+        # TODO: rework iree.runtime.FunctionInvoker interface for multiple devices.
+        # This works, but does not look right.
+        device=devices[0],
+        vm_function=vm_function,
+    )
+    results = invoker(*module_input_args)
+    shards = [torch.tensor(tensor.to_host()) for tensor in results]
+    return SplitPrimitiveTensor(ts=shards, shard_dim=1)
+
+
+def run_test_sharded_conv2d_with_iree(
+    mlir_path: Path, module_path: Path, parameters_path: Path, caching: bool
+):
+    torch.set_default_dtype(torch.float32)
+    torch.manual_seed(123456)
+    batches = 2
+    in_channels = 6
+    out_channels = 8
+    height = 11
+    width = 13
+    kernel_height = 5
+    kernel_width = 5
+    shard_count = 2
+    unsharded_theta = Theta(
+        {
+            "weight": DefaultPrimitiveTensor(
+                data=torch.rand(
+                    out_channels,
+                    in_channels,
+                    kernel_height,
+                    kernel_width,
+                )
+            ),
+        }
+    )
+    unsharded_theta.rename_tensors_to_paths()
+
+    if not caching or not os.path.exists(parameters_path):
+        sharding_spec = Conv2DSplitOutputChannelSharding(shard_count=shard_count)
+        sharded_theta = ops.reshard(unsharded_theta, sharding_spec)
+
+        # Roundtrip the dataset, which anchors the tensors as parameters to be loaded
+        # vs constants to be frozen (TODO: This is a bit wonky).
+        sharded_dataset = Dataset({}, sharded_theta)
+        sharded_dataset.save(parameters_path)
+
+    sharded_dataset = Dataset.load(parameters_path)
+
+    input_image = torch.rand(
+        batches,
+        in_channels,
+        height,
+        width,
+    )
+
+    sharded_torch_module = Conv2DLayer(sharded_dataset.root_theta, padding=(0, 0))
+    sharded_input_image = ops.reshard_split(input_image, dim=1, count=shard_count)
+    expected_result = sharded_torch_module(sharded_input_image)
+
+    if not caching or not os.path.exists(module_path):
+        exported_module = aot.export(
+            sharded_torch_module,
+            args=(sharded_input_image,),
+        )
+        exported_module.save_mlir(mlir_path)
+
+        compile_iree_module(
+            export_output=exported_module,
+            module_path=module_path,
+            shard_count=shard_count,
+        )
+
+    actual_result = run_iree_module(
+        sharded_input_image=sharded_input_image,
+        module_path=module_path,
+        parameters_path=parameters_path,
+    )
+    assert len(actual_result.shards) == len(expected_result.shards)
+    assert actual_result.shard_dim == expected_result.shard_dim
+    # TODO: reenable this check once numerical issues are resolved.
+    # See https://github.com/iree-org/iree/issues/18283
+    # for actual_shard, expected_shard in zip(
+    #     actual_result.shards, expected_result.shards
+    # ):
+    #     torch.testing.assert_close(
+    #         unbox_tensor(actual_shard), unbox_tensor(expected_shard)
+    #     )
+
+
+def test_sharded_conv2d_with_iree(
+    mlir_path: Optional[Path],
+    module_path: Optional[Path],
+    parameters_path: Optional[Path],
+    caching: bool,
+):
+    """Test sharding, exporting, and running a 2D convolution layer with IREE."""
+
+    with tempfile.TemporaryDirectory(
+        # TODO: verify hypothesis and remove ignore_cleanup_errors=True after a fix.
+        # torch.export.export is spawning some processes that don't exit when the
+        # function returns. This causes some objects to not get destroyed, which
+        # in turn holds files params.rank0.irpa and params.rank1.irpa open.
+        ignore_cleanup_errors=True
+    ) as tmp_dir:
+        mlir_path = Path(tmp_dir) / "model.mlir" if mlir_path is None else mlir_path
+        module_path = (
+            Path(tmp_dir) / "module.vmfb" if module_path is None else module_path
+        )
+        parameters_path = (
+            Path(tmp_dir) / "params.irpa"
+            if parameters_path is None
+            else parameters_path
+        )
+        run_test_sharded_conv2d_with_iree(
+            mlir_path, module_path, parameters_path, caching
+        )
diff --git a/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py b/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py
index ee3c42926..d4c1126e2 100644
--- a/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py
+++ b/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py
@@ -40,6 +40,8 @@ def compile_iree_module(
     export_output.compile(save_to=module_path, target_backends=None)
 
 
+# TODO: improve IREE's Python API to be more concise in a multi-device context.
+# This run function should be way shorter.
 def run_iree_module(
     sharded_input_image: ShardedTensor,
     sharded_input_time_emb: ShardedTensor,
@@ -206,6 +208,7 @@ def run_test_sharded_resnet_block_with_iree(
     )
     assert len(actual_result.shards) == len(expected_result.shards)
     # TODO: reenable this check once numerical issues are resolved.
+    # See https://github.com/iree-org/iree/issues/18283
     # for actual_shard, expected_shard in zip(
     #     actual_result.shards, expected_result.shards
     # ):