From e486ad4a377c0d9c03a4751a80aa71fdd3a0bdb5 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Wed, 11 Dec 2024 18:07:36 -0800 Subject: [PATCH] fix tests --- .../examples/sharding/export_ffn_net.py | 6 ++- sharktank/sharktank/models/llama/llama.py | 7 +++- sharktank/sharktank/ops/sharded_impls.py | 42 +++++++++++++++---- sharktank/sharktank/ops/signatures.py | 8 ++-- sharktank/sharktank/types/tensors.py | 12 +++--- sharktank/tests/layers/kv_cache_test.py | 11 +++-- .../layers/sharded_paged_kv_cache_test.py | 16 ++++--- .../layers/sharded_rotary_embedding_test.py | 6 ++- .../tests/models/llama/sharded_llama_test.py | 2 +- sharktank/tests/ops/sharded_test.py | 6 +-- sharktank/tests/types/tensors_test.py | 4 +- 11 files changed, 83 insertions(+), 37 deletions(-) diff --git a/sharktank/sharktank/examples/sharding/export_ffn_net.py b/sharktank/sharktank/examples/sharding/export_ffn_net.py index f261a92e1..7d1b1a2be 100644 --- a/sharktank/sharktank/examples/sharding/export_ffn_net.py +++ b/sharktank/sharktank/examples/sharding/export_ffn_net.py @@ -17,6 +17,8 @@ import torch import torch.nn as nn +from iree.turbine.aot import DeviceAffinity + from ...layers import * from ... import ops from ...types import * @@ -50,7 +52,9 @@ def forward(self, x: torch.Tensor): ffn_gate_weight = self.theta.tensor("ffn_gate", "weight") ffn_up_weight = self.theta.tensor("ffn_up", "weight") ffn_down_weight = self.theta.tensor("ffn_down", "weight") - x = ops.replicate(x, count=ffn_gate_weight.shard_count) + + devices = [DeviceAffinity(i) for i in range(ffn_down_weight.shard_count)] + x = ops.replicate(x, devices=devices) ffn_gate = ops.elementwise( torch.nn.functional.silu, ops.linear(x, ffn_gate_weight) ) diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py index fa481328d..367082f48 100644 --- a/sharktank/sharktank/models/llama/llama.py +++ b/sharktank/sharktank/models/llama/llama.py @@ -9,6 +9,8 @@ from dataclasses import dataclass from typing import Union +from iree.turbine.aot import DeviceAffinity + import torch import torch.nn as nn import torch.nn.functional as F @@ -62,7 +64,7 @@ class PagedLlamaModelV1(BaseCausalLMModel): unsharded result or chain it with other tensor-parallel operations. 
""" - def __init__(self, theta: Theta, config: LlamaModelConfig, devices: list): + def __init__(self, theta: Theta, config: LlamaModelConfig, devices: list = None): hp = config.hp super().__init__( theta, @@ -80,6 +82,9 @@ def __init__(self, theta: Theta, config: LlamaModelConfig, devices: list): self.use_hf = config.use_hf self.attention_kernel = config.attention_kernel + if devices is None: + devices = [DeviceAffinity(i) for i in range(config.tensor_parallelism_size)] + self.add_module( "token_embedding", TokenEmbeddingLayer(theta("token_embd"), dtype=config.activation_dtype), diff --git a/sharktank/sharktank/ops/sharded_impls.py b/sharktank/sharktank/ops/sharded_impls.py index 45505648a..00ebd0138 100644 --- a/sharktank/sharktank/ops/sharded_impls.py +++ b/sharktank/sharktank/ops/sharded_impls.py @@ -923,38 +923,62 @@ def repeat_replicated(input: ReplicatedTensor, *sizes: List[int]) -> ReplicatedT @replicate.override(ReplicatedTensor) -def replicate_replicated(input: ReplicatedTensor, *, devices: list) -> ReplicatedTensor: - if input.shard_count != len(devices): +def replicate_replicated( + input: ReplicatedTensor, *, devices: list, count: int +) -> ReplicatedTensor: + if devices is not None and input.shard_count != len(devices): raise ValueError( f"Number of shards not equal ({input.shard_count} != {len(devices)})" ) + if count is not None and input.shard_count != count: + raise ValueError(f"Number of shards not equal ({input.shard_count} != {count})") return input @replicate.override(SplitPrimitiveTensor) -def replicate_split(input: SplitPrimitiveTensor, *, devices: list) -> ReplicatedTensor: - if input.shard_count != len(devices): +def replicate_split( + input: SplitPrimitiveTensor, *, devices: list, count: int +) -> ReplicatedTensor: + if devices is not None and input.shard_count != len(devices): raise ValueError( f"Number of shards not equal ({input.shard_count} != {len(devices)})" ) + if count is not None and input.shard_count != count: + raise ValueError(f"Number of shards not equal ({input.shard_count} != {count})") return all_gather(input) @replicate.override(UnreducedTensor) -def replicate_unreduced(input: UnreducedTensor, *, devices: list) -> ReplicatedTensor: - if input.shard_count != len(devices): +def replicate_unreduced( + input: UnreducedTensor, *, devices: list, count: int +) -> ReplicatedTensor: + if devices is not None and input.shard_count != len(devices): raise ValueError( f"Number of shards not equal ({input.shard_count} != {len(devices)})" ) + if count is not None and input.shard_count != count: + raise ValueError(f"Number of shards not equal ({input.shard_count} != {count})") return all_reduce(input) @replicate.override(Tensor) -def replicate_unsharded(input, *, devices: list) -> ReplicatedTensor: +def replicate_unsharded(input, *, devices: list, count: int) -> ReplicatedTensor: torch_input = unbox_tensor(input) # If we have a torch input replicating we can assume we need to transfer: - torch_inputs = [transfer_to_logical_device(torch_input, d.ordinal) for d in devices] - return ReplicatedTensor(ts=torch_inputs, devices=devices) + if devices is not None: + torch_inputs = [ + transfer_to_logical_device(torch_input, d.ordinal) for d in devices + ] + return ReplicatedTensor(ts=torch_inputs, devices=devices) + + if count is not None: + devices = [DeviceAffinity(i) for i in range(count)] + torch_inputs = [ + transfer_to_logical_device(torch_input, i) for i in range(count) + ] + return ReplicatedTensor(ts=torch_inputs, devices=devices) + + raise ValueError(f"Devices or 
count is required") @reshape.override(SplitPrimitiveTensor) diff --git a/sharktank/sharktank/ops/signatures.py b/sharktank/sharktank/ops/signatures.py index b0504fdfa..c5dadb698 100644 --- a/sharktank/sharktank/ops/signatures.py +++ b/sharktank/sharktank/ops/signatures.py @@ -796,7 +796,9 @@ def _repeat_trampoline( @overridable -def replicate(input: AnyTensor, devices: list) -> ShardedTensor: +def replicate( + input: AnyTensor, devices: list = None, count: int = None +) -> ShardedTensor: """Replicate across devices. Possibly reshards if required.""" @@ -805,11 +807,11 @@ def replicate(input: AnyTensor, devices: list) -> ShardedTensor: @replicate.trampoline def _replicate_trampoline( - d: SignatureDispatcher, input: AnyTensor, devices: list + d: SignatureDispatcher, input: AnyTensor, devices: list = None, count: int = None ) -> ShardedTensor: tensors = (input,) for override in d.find_overrides(tensors): - result = override(input, devices=devices) + result = override(input, devices=devices, count=count) if result is not NotImplemented: return override, result else: diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py index b2638d451..c4fad275f 100644 --- a/sharktank/sharktank/types/tensors.py +++ b/sharktank/sharktank/types/tensors.py @@ -813,9 +813,6 @@ def __init__( self._devices: tuple[DeviceAffinity] = tuple(devices) - for i, t in enumerate(ts): - DeviceTensorTrait(i).set(t) - def assign_affinities(self, affinities): assert len(affinities) == len(self._devices) self._devices = tuple(affinities) @@ -893,6 +890,7 @@ def create( t_name = str(i) try: t = raw_tensors[t_name] + DeviceTensorTrait(i).set(t) ts.append(t) except KeyError as e: raise IOError( @@ -996,7 +994,7 @@ def __init__( number of pieces. """ - assert shard_count is None or not isinstance(ts, torch.Tensor) + assert shard_count is None or isinstance(ts, torch.Tensor) shard_count = shard_count if shard_count is not None else len(ts) if devices is None: @@ -1169,9 +1167,6 @@ def __init__( self._devices: tuple[DeviceAffinity] = tuple(devices) - for d, t in zip(devices, ts): - DeviceTensorTrait(d.ordinal, d.queues).set(t) - def assign_affinities(self, affinities): assert len(affinities) == len(self._devices) self._devices = tuple(affinities) @@ -1238,6 +1233,9 @@ def create( nt = deepcopy(t) ts.append(nt) + for i, t in enumerate(ts): + DeviceTensorTrait(i).set(t) + except KeyError as e: raise IOError(f"Missing component tensor '' in {raw_tensors.keys()}") from e return cls(name=name, ts=ts) diff --git a/sharktank/tests/layers/kv_cache_test.py b/sharktank/tests/layers/kv_cache_test.py index 65b42c986..228535ee8 100644 --- a/sharktank/tests/layers/kv_cache_test.py +++ b/sharktank/tests/layers/kv_cache_test.py @@ -8,6 +8,7 @@ import torch +from iree.turbine.aot import DeviceAffinity from sharktank.ops import replicate, reshard_split, unshard from sharktank.layers import * from sharktank.types import * @@ -148,6 +149,8 @@ def test_sharded_direct(): write_seq_length = seq_length - 5 + devices = [DeviceAffinity(i) for i in range(shard_count)] + # Write a prefill in: write_ones = reshard_split( torch.full( @@ -204,7 +207,7 @@ def test_sharded_direct(): ) write_pos = replicate( - torch.full((bs,), write_seq_length, dtype=torch.int64), shard_count + torch.full((bs,), write_seq_length, dtype=torch.int64), devices ) cache.write_timestep( allocation, @@ -379,11 +382,13 @@ def test_sharded_paged(): device=None, ) + devices = [DeviceAffinity(i) for i in range(shard_count)] + write_seq_length = seq_length - 4 
page_count = bs * seq_length // block_seq_stride page_ids = torch.arange(page_count, dtype=torch.int64) page_ids = page_ids.view(bs, seq_length // block_seq_stride) - page_ids = replicate(page_ids, shard_count) + page_ids = replicate(page_ids, devices=devices) write_page_ids = page_ids[:, : write_seq_length // block_seq_stride] allocation = cache.allocate(page_count=page_count) @@ -458,7 +463,7 @@ def test_sharded_paged(): ) write_pos = replicate( - torch.full((bs,), write_seq_length, dtype=torch.int64), shard_count + torch.full((bs,), write_seq_length, dtype=torch.int64), devices ) cache.write_timestep( diff --git a/sharktank/tests/layers/sharded_paged_kv_cache_test.py b/sharktank/tests/layers/sharded_paged_kv_cache_test.py index d7b6a0b33..c54b20459 100644 --- a/sharktank/tests/layers/sharded_paged_kv_cache_test.py +++ b/sharktank/tests/layers/sharded_paged_kv_cache_test.py @@ -5,13 +5,16 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import unittest -from sharktank.layers import PagedKVCache import torch -from sharktank.utils import iterables_equal from copy import deepcopy from typing import List, Tuple + +from iree.turbine.aot import DeviceAffinity + from sharktank import ops +from sharktank.layers import PagedKVCache from sharktank.types import SplitPrimitiveTensor +from sharktank.utils import iterables_equal class ShardedPagedKVCacheTest(unittest.TestCase): @@ -31,6 +34,7 @@ def setUp(self): self.batch_size = 11 self.block_seq_len = 2 self.max_seq_len = self.block_seq_len * self.block_seq_stride + self.devices = [DeviceAffinity(i) for i in range(self.shard_count)] self.cache = PagedKVCache( transformer_block_count=self.transformer_block_count, @@ -131,7 +135,7 @@ def testRead(self): for t in read_into_partitions_snapshot ] ) - sharded_page_ids = ops.replicate(page_ids, count=self.shard_count) + sharded_page_ids = ops.replicate(page_ids, devices=self.devices) self.sharded_cache.read( state=sharded_cache_state, read_into_partitions=sharded_read_into_partitions, @@ -179,8 +183,8 @@ def testWriteTimestep(self): for t in cache_partitions ] ) - sharded_seq_positions = ops.replicate(seq_positions, count=self.shard_count) - sharded_page_ids = ops.replicate(page_ids, count=self.shard_count) + sharded_seq_positions = ops.replicate(seq_positions, devices=self.devices) + sharded_page_ids = ops.replicate(page_ids, devices=self.devices) self.sharded_cache.write_timestep( state=sharded_cache_state, cache_partitions=sharded_cache_partitions, @@ -224,7 +228,7 @@ def testWrite(self): for t in cache_partitions ] ) - sharded_page_ids = ops.replicate(page_ids, count=self.shard_count) + sharded_page_ids = ops.replicate(page_ids, devices=self.devices) self.sharded_cache.write( state=sharded_cache_state, cache_partitions=sharded_cache_partitions, diff --git a/sharktank/tests/layers/sharded_rotary_embedding_test.py b/sharktank/tests/layers/sharded_rotary_embedding_test.py index f24b8313a..2d2d95691 100644 --- a/sharktank/tests/layers/sharded_rotary_embedding_test.py +++ b/sharktank/tests/layers/sharded_rotary_embedding_test.py @@ -7,6 +7,8 @@ import torch +from iree.turbine.aot import DeviceAffinity + from sharktank.layers import RotaryEmbeddingLayer from sharktank import ops from sharktank.types import ( @@ -27,6 +29,8 @@ def test_sharded_rotary_table(): max_seqlen = 128 rope_freq_base = None + devices = [DeviceAffinity(i) for i in range(4)] + # First we setup and get the default rotary embedding layer xq = torch.rand((bs, max_seqlen, heads, rope_dims), dtype=torch.float) xk = torch.rand((bs, 
max_seqlen, heads, rope_dims), dtype=torch.float) @@ -45,7 +49,7 @@ def test_sharded_rotary_table(): rope_dimension_count=rope_dims, max_seqlen=max_seqlen, rope_freq_base=rope_freq_base, - tensor_parallelism_size=4, + devices=devices, ) sq = shard_layer(xt=xq, start_index=0) sk = shard_layer(xt=xk, start_index=0) diff --git a/sharktank/tests/models/llama/sharded_llama_test.py b/sharktank/tests/models/llama/sharded_llama_test.py index 386061731..b1d0aa745 100644 --- a/sharktank/tests/models/llama/sharded_llama_test.py +++ b/sharktank/tests/models/llama/sharded_llama_test.py @@ -6,7 +6,7 @@ import unittest import pytest -from typing import Any, List, Tuple, OrderedDict +from typing import Any, Tuple, OrderedDict from sharktank.models.llama.llama import LlamaModelConfig, PagedLlamaModelV1 import sharktank.ops as ops from sharktank.types import unbox_tensor, Dataset, UnreducedTensor, SplitPrimitiveTensor diff --git a/sharktank/tests/ops/sharded_test.py b/sharktank/tests/ops/sharded_test.py index e5efaa948..adfb38359 100644 --- a/sharktank/tests/ops/sharded_test.py +++ b/sharktank/tests/ops/sharded_test.py @@ -630,7 +630,7 @@ def testAttentionShardedBatchMask(self): q_s = SplitPrimitiveTensor(shard_dim=0, ts=q.split(1, dim=0)) k_s = SplitPrimitiveTensor(shard_dim=0, ts=k.split(1, dim=0)) v_s = SplitPrimitiveTensor(shard_dim=0, ts=v.split(1, dim=0)) - a_s = ReplicatedTensor(ts=a, shard_count=4) + a_s = ReplicatedTensor(ts=[a] * 4) expected_result = ops.scaled_dot_product_attention( q, k, v, a=a, is_causal=False @@ -754,7 +754,7 @@ def testShardedLhsReplcatedRhs(self): expected_result = torch.matmul(a, b) shard_count = 3 a_sharded = SplitPrimitiveTensor(ts=a, shard_dim=1, shard_count=shard_count) - b_sharded = ReplicatedTensor(ts=b, shard_count=shard_count) + b_sharded = ReplicatedTensor(ts=[b] * shard_count) res_sharded = ops.matmul(a_sharded, b_sharded) assert isinstance(res_sharded, SplitPrimitiveTensor) assert res_sharded.shard_dim == 1 @@ -837,7 +837,7 @@ def testReplicateUnsharded(self): tensor = torch.rand(4, 5, dtype=torch.float32) shard_count = 3 actual_result = ops.replicate(tensor, count=shard_count) - expected_result = ReplicatedTensor(ts=tensor, shard_count=shard_count) + expected_result = ReplicatedTensor(ts=[tensor] * shard_count) assert expected_result.is_deep_equal(actual_result) # Test that is a copy. diff --git a/sharktank/tests/types/tensors_test.py b/sharktank/tests/types/tensors_test.py index 4af4a513f..37898a7c5 100644 --- a/sharktank/tests/types/tensors_test.py +++ b/sharktank/tests/types/tensors_test.py @@ -106,7 +106,7 @@ def testUnreducedTensorSaveLoad(self): def testReplicatedTensorExtractSlice(self): tensor = torch.rand([2, 3, 4], dtype=torch.float32) - replicated_tensor = ReplicatedTensor(ts=tensor, shard_count=3) + replicated_tensor = ReplicatedTensor(ts=[tensor] * 3) s = [slice(1, 2), slice(0, 3, 2), None] expected_result = tensor[s] replicated_sliced_tensor = replicated_tensor[s] @@ -116,7 +116,7 @@ def testReplicatedTensorExtractSlice(self): def testReplicatedTensorExtractElement(self): tensor = torch.rand([2, 3, 4], dtype=torch.float32) - replicated_tensor = ReplicatedTensor(ts=tensor, shard_count=3) + replicated_tensor = ReplicatedTensor(ts=[tensor] * 3) idx = ( 1, 2,
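
Note on the updated replicate signature (a minimal usage sketch, not part of the patch itself): after this change, ops.replicate accepts either an explicit devices list of DeviceAffinity objects or the legacy count keyword, and exactly one of the two is expected. The snippet below assumes a sharktank build that includes this patch; the variable names are illustrative and it only exercises the two calling conventions already used by the tests above.

    import torch

    from iree.turbine.aot import DeviceAffinity
    from sharktank import ops
    from sharktank.types import ReplicatedTensor

    x = torch.rand(4, 5, dtype=torch.float32)

    # New style: pass explicit device affinities, one per shard.
    devices = [DeviceAffinity(i) for i in range(3)]
    replicated = ops.replicate(x, devices=devices)
    assert replicated.shard_count == 3

    # Legacy style: pass a shard count; per this patch the affinities
    # default to DeviceAffinity(0) .. DeviceAffinity(count - 1).
    replicated_by_count = ops.replicate(x, count=3)
    assert replicated_by_count.shard_count == 3

    # Constructing a ReplicatedTensor directly now takes a list of shards
    # rather than a single tensor plus shard_count (see tensors_test.py above).
    manual = ReplicatedTensor(ts=[x] * 3)
    assert manual.shard_count == 3

Passing both devices and count where they disagree with the input's shard count raises a ValueError in the sharded overrides, so callers migrating from count to devices only need to build the DeviceAffinity list once and thread it through, as the kv-cache and rotary-embedding tests in this patch do.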