From 9893596bb8dce95d3e468ddbde69bab106472558 Mon Sep 17 00:00:00 2001
From: Rob Suderman
Date: Wed, 11 Dec 2024 16:32:27 -0800
Subject: [PATCH 1/2] Expanded sharding support for alternative sharding mechanisms

Single-logical-multi-physical sharding allows tensor access between
different devices and tighter synchronization on execution. This means
that sharding needs to support not only differing device ordinals but
also configuring multiple queues for the same device. Sharded tensor
types are reworked to track both the supported device AND the queue it
is enqueued on.

To support this, each sharded tensor now tracks the DeviceAffinity it is
associated with, and affinities can be reassigned after construction.
This allows pre-sharded models to have their affinities updated with an
alternative transfer mechanism. If device affinities are not specified,
the default arrangement assumes a separate device ordinal for each
shard.
---
 .../sharktank/examples/export_paged_llm_v1.py | 82 ++++---
 sharktank/sharktank/examples/paged_llm_v1.py | 23 +-
 sharktank/sharktank/layers/kv_cache.py | 4 +-
 .../sharktank/layers/rotary_embedding.py | 19 +-
 sharktank/sharktank/models/llama/llama.py | 12 +-
 sharktank/sharktank/ops/sharded_impls.py | 201 +++++++++++-------
 sharktank/sharktank/ops/signatures.py | 6 +-
 sharktank/sharktank/types/tensors.py | 92 ++++++--
 sharktank/sharktank/types/theta.py | 2 -
 sharktank/sharktank/utils/cli.py | 7 +-
 10 files changed, 299 insertions(+), 149 deletions(-)

diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py
index 900c1a9ae..945c418c3 100644
--- a/sharktank/sharktank/examples/export_paged_llm_v1.py
+++ b/sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -18,11 +18,12 @@
# TODO: Should be using a base class with the protocol supported.
from ..models.llama.llama import LlamaModelConfig, PagedLlamaModelV1
-from ..models.llama.sharding import shard_theta
from ..models.mixtral.mixtral import *
from ..models.grok.grok import *
from .. import ops
+from typing import Union, Sequence
+
def main():
from ..utils import cli
@@ -55,6 +56,11 @@ def main():
help="Enables strictness during export",
action="store_true",
)
+ parser.add_argument(
+ "--use-queue-affinities",
+ help="Enables queue affinities for multidevice",
+ action="store_true",
+ )
cli.add_quantization_options(parser)
cli.add_model_options(parser)
@@ -74,18 +80,44 @@ def main():
tensor_parallelism_size=tensor_parallelism_size,
use_hf=False,
static_tables=False, # Rely on the compiler for hoisting tables.
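To make the affinity mechanism described in the commit message concrete, here is a minimal usage sketch. It is illustrative only and not part of the patch; it assumes the DeviceAffinity(ordinal, queues) constructor from iree.turbine.aot and the devices=/assign_affinities APIs introduced by this change, and the shapes, shard count, and variable names are hypothetical.

import torch
from iree.turbine.aot import DeviceAffinity
from sharktank import ops
from sharktank.types import SplitPrimitiveTensor

shard_count = 4

# Default arrangement: a separate device ordinal for each shard.
per_device = [DeviceAffinity(i) for i in range(shard_count)]

# Single-logical-multi-physical: one device ordinal, a distinct queue per shard.
per_queue = [DeviceAffinity(0, [str(i)]) for i in range(shard_count)]

# Replicate an unsharded tensor onto the chosen affinities.
replicated = ops.replicate(torch.rand(4, 5), devices=per_queue)

# Build a split tensor on the default affinities, then retarget the
# pre-sharded tensor onto the queue affinities after construction.
split = SplitPrimitiveTensor(
    ts=[torch.rand(4, 2) for _ in range(shard_count)],
    shard_dim=1,
    devices=per_device,
)
split.assign_affinities(per_queue)

The export script below follows the same pattern: it builds the affinity list once, passes it to the model and to ops.replicate, and reassigns it onto a pre-sharded theta when --use-queue-affinities is set.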
- kv_cache_type="direct" if args.bs == [1] else "paged", + kv_cache_type="paged" if args.bs == [1] else "paged", attention_kernel=args.attention_kernel, ) llama_config.fake_quant = args.fake_quant + def setup_queue_affinities(sharding): + return [DeviceAffinity(0, [str(i)]) for i in range(sharding)] + + def assign_affinities(theta, affinities): + def transform(tensors: Union[None, InferenceTensor, Sequence[InferenceTensor]]): + if tensors is None: + return tensors + if isinstance(tensors, Sequence): + return [transform(t) for t in tensors] + if isinstance(tensors, ShardedTensor): + tensors.assign_affinities(affinities) + return tensors + if isinstance(tensors, InferenceTensor): + return tensors + + raise ValueError("Unknown device for reassigned affinities") + + return theta.transform(transform) + + affinities = [ + DeviceAffinity(i) for i in range(llama_config.tensor_parallelism_size) + ] + if args.use_queue_affinities: + affinities = setup_queue_affinities(llama_config.tensor_parallelism_size) + dataset.root_theta = assign_affinities(dataset.root_theta, affinities) + if llama_config.hp.expert_count: if llama_config.hp.model_arch == "grok": model = PagedGrokModelV1(dataset.root_theta, llama_config) else: model = PagedMixtralModelV1(dataset.root_theta, llama_config) else: - model = PagedLlamaModelV1(dataset.root_theta, llama_config) + model = PagedLlamaModelV1(dataset.root_theta, llama_config, affinities) def generate_params_json( hp: LlamaHParams, prefill_bs: list[int], decode_bs: list[int] @@ -121,7 +153,7 @@ def generate_params_json( fxb = FxProgramsBuilder(model) - def setup_cache(model, shard_count): + def setup_cache(model, affinities): if model.config.kv_cache_type == "paged": cache_state = model.cache.allocate( page_count=hp.context_length // llama_config.block_seq_stride @@ -143,24 +175,21 @@ def setup_cache(model, shard_count): ] for i in range(llama_config.tensor_parallelism_size): - arg_affinities[i] = DeviceAffinity(str(i)) + arg_affinities[i] = affinities[i] return unpacked, shard_dim, dynamic_shapes, arg_affinities elif model.config.kv_cache_type == "direct": - cache_state = model.cache.allocate(bs=1) - # Direct cache dimensions: - # 2 * transformer_block_count of... 
- # [bs, seq_length, attn_head_count, attn_head_dim] - dynamic_shapes = [None] - arg_affinities = {} - shard_dim = None - return torch.stack(cache_state), shard_dim, dynamic_shapes, arg_affinities + raise NotImplementedError(f"Direct cache is not currently functional") + else: raise NotImplementedError(f"Unsupported KV cache type: {type(model.cache)}") def repack_cache(cache, shard_dim): - return [SplitPrimitiveTensor(ts=c, shard_dim=shard_dim) for c in cache] + return [ + SplitPrimitiveTensor(ts=c, shard_dim=shard_dim, devices=affinities) + for c in cache + ] def generate_batch_prefill(bs: int): # torch.export.Dim would make min at least 2 @@ -177,7 +206,7 @@ def generate_batch_prefill(bs: int): seq_lens = torch.empty(bs, dtype=torch.int64) cache, cache_shard_dim, cache_dynamic_shapes, arg_affinities = setup_cache( - model, llama_config.tensor_parallelism_size + model, affinities ) if llama_config.tensor_parallelism_size > 1: @@ -219,9 +248,9 @@ def _(model, tokens, seq_lens, seq_block_ids, cs): if llama_config.tensor_parallelism_size != 1: shard_count = llama_config.tensor_parallelism_size - tokens = ops.replicate(tokens, count=shard_count) - attention_mask = ops.replicate(attention_mask, count=shard_count) - seq_block_ids = ops.replicate(seq_block_ids, count=shard_count) + tokens = ops.replicate(tokens, devices=affinities) + attention_mask = ops.replicate(attention_mask, devices=affinities) + seq_block_ids = ops.replicate(seq_block_ids, devices=affinities) cache_tensors = repack_cache(cs, cache_shard_dim) @@ -256,7 +285,7 @@ def generate_batch_decode(bs: int): cache_shard_dim, cache_dynamic_shapes, arg_affinities, - ) = setup_cache(model, llama_config.tensor_parallelism_size) + ) = setup_cache(model, affinities) if llama_config.tensor_parallelism_size > 1: # We need to offset the indices for the cache @@ -264,7 +293,7 @@ def generate_batch_decode(bs: int): # Inputs have default affinity 0 for i in range(4): - arg_affinities[i] = DeviceAffinity("0") + arg_affinities[i] = affinities[0] dynamic_shapes = { "tokens": {}, @@ -303,12 +332,10 @@ def _( attention_mask = model.decode_attention_mask(input_mask) if llama_config.tensor_parallelism_size != 1: - shard_count = llama_config.tensor_parallelism_size - - tokens = ops.replicate(tokens, count=shard_count) - attention_mask = ops.replicate(attention_mask, count=shard_count) - start_positions = ops.replicate(start_positions, count=shard_count) - seq_block_ids = ops.replicate(seq_block_ids, count=shard_count) + tokens = ops.replicate(tokens, devices=affinities) + attention_mask = ops.replicate(attention_mask, devices=affinities) + start_positions = ops.replicate(start_positions, devices=affinities) + seq_block_ids = ops.replicate(seq_block_ids, devices=affinities) cache_state = repack_cache(cache_state, cache_shard_dim) @@ -327,7 +354,8 @@ def _( bsizes = [] for bs in args.bs: - generate_batch_prefill(bs) + if not args.skip_prefill: + generate_batch_prefill(bs) if not args.skip_decode: generate_batch_decode(bs) bsizes.append(bs) diff --git a/sharktank/sharktank/examples/paged_llm_v1.py b/sharktank/sharktank/examples/paged_llm_v1.py index b30acc026..cc1c96d8c 100644 --- a/sharktank/sharktank/examples/paged_llm_v1.py +++ b/sharktank/sharktank/examples/paged_llm_v1.py @@ -15,6 +15,8 @@ import torch +from iree.turbine.aot import DeviceAffinity + from ..layers import * from ..types import * @@ -156,9 +158,10 @@ def prefill(self): token_ids = self.token_ids if model.config.tensor_parallelism_size != 1: tp = model.config.tensor_parallelism_size - 
token_ids = replicate(token_ids, tp) - attention_mask = replicate(attention_mask, tp) - seq_block_ids_tensor = replicate(seq_block_ids_tensor, tp) + devices = [DeviceAffinity(i) for i in range(tp)] + token_ids = replicate(token_ids, devices) + attention_mask = replicate(attention_mask, devices) + seq_block_ids_tensor = replicate(seq_block_ids_tensor, devices) logits = model.prefill( token_ids, @@ -199,10 +202,11 @@ def decode(self): if model.config.tensor_parallelism_size != 1: tp = model.config.tensor_parallelism_size - self.next_tokens = replicate(self.next_tokens, tp) - start_positions = replicate(start_positions, tp) - seq_block_ids_tensor = replicate(seq_block_ids_tensor, tp) - decode_attention_mask = replicate(decode_attention_mask, tp) + devices = [DeviceAffinity(i) for i in range(tp)] + self.next_tokens = replicate(self.next_tokens, devices) + start_positions = replicate(start_positions, devices) + seq_block_ids_tensor = replicate(seq_block_ids_tensor, devices) + decode_attention_mask = replicate(decode_attention_mask, devices) logits = model.decode( self.next_tokens, @@ -279,8 +283,7 @@ def main(): tensor_parallelism_size=args.tensor_parallelism_size, fake_quant=args.fake_quant, ) - if config.tensor_parallelism_size > 1: - dataset.root_theta = shard_theta(dataset.root_theta, config) + affinities = [DeviceAffinity(i) for i in range(args.tensor_parallelism_size)] if config.hp.expert_count: if config.hp.model_arch == "grok": @@ -288,7 +291,7 @@ def main(): else: model = PagedMixtralModelV1(dataset.root_theta, config) else: - model = PagedLlamaModelV1(dataset.root_theta, config) + model = PagedLlamaModelV1(dataset.root_theta, config, devices=affinities) if args.save_intermediates_path: from ..utils.patching import SaveModuleResultTensorsPatch diff --git a/sharktank/sharktank/layers/kv_cache.py b/sharktank/sharktank/layers/kv_cache.py index 46e94ff90..82973b576 100644 --- a/sharktank/sharktank/layers/kv_cache.py +++ b/sharktank/sharktank/layers/kv_cache.py @@ -292,7 +292,9 @@ def unflatten_page_table( shards = [ shard.unflatten(1, self.sub_page_dims) for shard in page_slab.shards ] - return SplitPrimitiveTensor(ts=shards, shard_dim=4) + return SplitPrimitiveTensor( + ts=shards, shard_dim=4, devices=page_slab.devices + ) def shard_state( self, state: List[torch.Tensor] diff --git a/sharktank/sharktank/layers/rotary_embedding.py b/sharktank/sharktank/layers/rotary_embedding.py index 99ecf5057..3c1411fce 100644 --- a/sharktank/sharktank/layers/rotary_embedding.py +++ b/sharktank/sharktank/layers/rotary_embedding.py @@ -27,10 +27,11 @@ def __init__( use_hf: bool = False, static_tables: bool = False, use_table: bool = True, - tensor_parallelism_size: int = 1, + devices: list | None = None, ): super().__init__() self.device = device + self.devices = devices self.rope_dimension_count = rope_dimension_count self.max_seqlen = max_seqlen self.use_hf = use_hf @@ -38,7 +39,7 @@ def __init__( self.use_table = use_table self.rope_freq_base = rope_freq_base if rope_freq_base is not None else 10000.0 - self.tensor_parallelism_size = tensor_parallelism_size + self.devices = devices if static_tables: ops.module_register_buffer( self, "static_rotary_embed_table", self._create_rotary_embed_table() @@ -80,7 +81,9 @@ def forward( ) for xt_shard, rotary_shard in zip(xt.shards, rotary_shards) ] - xt = SplitPrimitiveTensor(ts=xt_shards, shard_dim=xt.shard_dim) + xt = SplitPrimitiveTensor( + ts=xt_shards, shard_dim=xt.shard_dim, devices=xt.devices + ) return xt else: return self.forward_unsharded( @@ -189,7 
+192,7 @@ def compute_batch_mask( self._compute_rotary_embed_table(s.flatten()).unflatten(0, shape) for s in positions_seq.shards ] - freqs_cis = ReplicatedTensor(ts=ts) + freqs_cis = ReplicatedTensor(ts=ts, devices=positions_seq.devices) else: freqs_cis = self._compute_rotary_embed_table(positions_seq.flatten()) freqs_cis = freqs_cis.unflatten(0, shape) @@ -215,7 +218,9 @@ def apply_batched_mask( ) for xt_shard, mask_shard in zip(xt.shards, mask.shards) ] - xt = SplitPrimitiveTensor(ts=xt_shards, shard_dim=xt.shard_dim) + xt = SplitPrimitiveTensor( + ts=xt_shards, shard_dim=xt.shard_dim, devices=xt.devices + ) return xt def apply_batched_mask_unsharded(self, *, xt: torch.Tensor, mask: torch.Tensor): @@ -259,8 +264,8 @@ def _create_rotary_embed_table(self): return self._replicate(freqs_cis) def _replicate(self, t): - if self.tensor_parallelism_size > 1: + if self.devices is not None and len(self.devices) > 1: # Replicate across all devices, the data is not a lot and the computation is cheap. - t = ops.replicate(t, self.tensor_parallelism_size) + t = ops.replicate(t, self.devices) return t diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py index 0a9a6f1c3..fa481328d 100644 --- a/sharktank/sharktank/models/llama/llama.py +++ b/sharktank/sharktank/models/llama/llama.py @@ -62,7 +62,7 @@ class PagedLlamaModelV1(BaseCausalLMModel): unsharded result or chain it with other tensor-parallel operations. """ - def __init__(self, theta: Theta, config: LlamaModelConfig): + def __init__(self, theta: Theta, config: LlamaModelConfig, devices: list): hp = config.hp super().__init__( theta, @@ -91,9 +91,9 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): rope_freq_base=hp.rope_freq_base, max_seqlen=hp.context_length, device=self.device, + devices=devices, use_hf=self.use_hf, static_tables=config.static_tables, - tensor_parallelism_size=config.tensor_parallelism_size, ), ) self.add_module( @@ -238,8 +238,12 @@ def decode( ) for _ in range(self.config.tensor_parallelism_size) ] - xk_temp = SplitPrimitiveTensor(ts=xk_temp_shard, shard_dim=2) - xv_temp = SplitPrimitiveTensor(ts=xv_temp_shard, shard_dim=2) + xk_temp = SplitPrimitiveTensor( + ts=xk_temp_shard, shard_dim=2, devices=tokens.devices + ) + xv_temp = SplitPrimitiveTensor( + ts=xv_temp_shard, shard_dim=2, devices=tokens.devices + ) h = self.token_embedding(tokens) self.trace_tensor("llama.token_embedding", h) diff --git a/sharktank/sharktank/ops/sharded_impls.py b/sharktank/sharktank/ops/sharded_impls.py index 015e88a4b..45505648a 100644 --- a/sharktank/sharktank/ops/sharded_impls.py +++ b/sharktank/sharktank/ops/sharded_impls.py @@ -12,6 +12,8 @@ import math import functools +from iree.turbine.aot import DeviceAffinity + from ..types import ( AnyTensor, DefaultPrimitiveTensor, @@ -42,14 +44,14 @@ def all_gather_split( shards = [ cat( [ - shard if i == j else transfer_to_logical_device(shard, i) - for j, shard in enumerate(input.shards) + transfer_to_logical_device(shard, device.ordinal) + for shard in input.shards ], dim=dim, ) - for i in range(input.shard_count) + for device in input.devices ] - return ReplicatedTensor(ts=shards) + return ReplicatedTensor(ts=shards, devices=input.devices) @all_reduce.override(AllOfType(SplitPrimitiveTensor, UnreducedTensor)) @@ -63,23 +65,25 @@ def all_reduce_split_or_unreduced( functools.reduce( lambda x, y: elementwise(torch.add, x, y), [ - shard if i == j else transfer_to_logical_device(shard, i) - for j, shard in enumerate(input.shards) + 
transfer_to_logical_device(shard, device.ordinal) + for shard in input.shards ], ) - for i in range(input.shard_count) + for device in input.devices ] - return ReplicatedTensor(ts=shards) + return ReplicatedTensor(ts=shards, devices=input.devices) @cat.override(AllOfType(ReplicatedTensor)) def cat_replicated(tensors: Sequence[ReplicatedTensor], dim: int) -> ReplicatedTensor: assert len(tensors) > 0 + devices = tensors[0].devices shard_count = tensors[0].shard_count assert all([t.shard_count == shard_count for t in tensors]) + assert all([t.devices == devices for t in tensors]) shards = [cat(shards, dim) for shards in zip(*[t.shards for t in tensors])] - return ReplicatedTensor(ts=shards) + return ReplicatedTensor(ts=shards, devices=devices) @cat.override(AllOfType(SplitPrimitiveTensor)) @@ -97,9 +101,12 @@ def cat_split( shard_dim = tensors[0].shard_dim shard_count = tensors[0].shard_count + devices = tensors[0].devices + for t in tensors: + assert t.devices == tensors[0].devices if dim != shard_dim: shards = [cat(shards, dim) for shards in zip(*[t.shards for t in tensors])] - return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim) + return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim, devices=devices) else: # TODO: implement efficient cat along split dim. # This would probably result in doing the concatenation on one device. @@ -191,6 +198,7 @@ def conv2d_replicated_input_split_weight_and_bias( bias is None or weight.shard_dim == 0 and bias.shard_dim == 0 ), "Only sharding of output channel dimension is supported" assert groups == 1 + assert input.devices == weight.devices shards = [ conv2d( @@ -208,7 +216,7 @@ def conv2d_replicated_input_split_weight_and_bias( [None] * weight.shard_count if bias is None else bias.shards, ) ] - return SplitPrimitiveTensor(shard_dim=1, ts=shards) + return SplitPrimitiveTensor(shard_dim=1, ts=shards, devices=input.devices) conv2d.override( @@ -252,7 +260,7 @@ def conv2d_split_weight_and_bias( [None] * weight.shard_count if bias is None else bias.shards, ) ] - return SplitPrimitiveTensor(shard_dim=1, ts=shards) + return SplitPrimitiveTensor(shard_dim=1, ts=shards, devices=weight.devices) else: assert False, "Unsupported, TODO: handle split channels in input" @@ -270,13 +278,15 @@ def conv2d_split_weight_and_bias( @elementwise.override(ReplicatedTensor) def replicated_elementwise_unary(operator, x: ReplicatedTensor, *args, **kwargs): partials = [operator(unbox_tensor(pt), *args, **kwargs) for pt in x.shards] - return ReplicatedTensor(ts=partials) + return ReplicatedTensor(ts=partials, devices=x.devices) @elementwise.override(SplitPrimitiveTensor) def split_elementwise_unary(operator, x: SplitPrimitiveTensor, *args, **kwargs): partials = [operator(unbox_tensor(pt), *args, **kwargs) for pt in x.shards] - return SplitPrimitiveTensor(shard_dim=x.shard_dim, shape=x.shape, ts=partials) + return SplitPrimitiveTensor( + shard_dim=x.shard_dim, shape=x.shape, ts=partials, devices=x.devices + ) @elementwise.override(ReplicatedTensor, ReplicatedTensor) @@ -284,11 +294,12 @@ def replicated_elementwise_binary( operator, x: ReplicatedTensor, y: ReplicatedTensor, *args, **kwargs ): assert x.shard_count == y.shard_count + assert x.devices == y.devices shards = [ operator(unbox_tensor(shard_x), unbox_tensor(shard_y), *args, **kwargs) for shard_x, shard_y in zip(x.shards, y.shards) ] - return ReplicatedTensor(ts=shards) + return ReplicatedTensor(ts=shards, devices=x.devices) @elementwise.override(SplitPrimitiveTensor, SplitPrimitiveTensor) @@ -298,6 +309,7 @@ def 
split_elementwise_binary( assert x.shard_count == y.shard_count x_shard_dim, y_shard_dim = broadcast_dims([x.shard_dim, y.shard_dim], [x, y]) assert x_shard_dim == y_shard_dim + assert x.devices == y.devices pt_xs = [unbox_tensor(pt) for pt in x.shards] pt_ys = [unbox_tensor(pt) for pt in y.shards] partials = [ @@ -307,6 +319,7 @@ def split_elementwise_binary( shard_dim=x.shard_dim, shape=torch.broadcast_shapes(x.shape, y.shape), ts=partials, + devices=x.devices, ) @@ -316,7 +329,9 @@ def elementwise_binary_split_lhs_scalar_rhs( ): pt_xs = [unbox_tensor(pt) for pt in x.shards] partials = [operator(pt_x, y, *args, **kwargs) for pt_x in pt_xs] - return SplitPrimitiveTensor(shard_dim=x.shard_dim, shape=x.shape, ts=partials) + return SplitPrimitiveTensor( + shard_dim=x.shard_dim, shape=x.shape, ts=partials, devices=x.devices + ) @elementwise.override(SplitPrimitiveTensor, Tensor) @@ -345,6 +360,7 @@ def elementwise_binary_split_lhs_replicated_rhs( operator, x: SplitPrimitiveTensor, y: ReplicatedTensor, *args, **kwargs ): assert len(y.shape) > 0, "0-rank not supported" + assert x.devices == y.devices if x.shard_count != y.shard_count: raise ValueError( f"Operands' number of shards not equal ({x.shard_count} != {y.shard_count})" @@ -360,7 +376,9 @@ def elementwise_binary_split_lhs_replicated_rhs( elementwise(operator, x_shard, y_shard) for x_shard, y_shard in zip(x.shards, y.shards) ] - return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim_in_res) + return SplitPrimitiveTensor( + ts=shards, shard_dim=shard_dim_in_res, devices=x.devices + ) y_sharded = reshard_like(y, like=x) return elementwise(operator, x, y_sharded, *args, **kwargs) @@ -402,13 +420,14 @@ def embedding_lookup_default( dtype: Optional[torch.dtype], ): assert input.shard_count == embedding_matrix.shard_count + assert input.devices == embedding_matrix.devices shards = [ embedding_lookup(input_shard, embedding_matrix_shard, dtype) for input_shard, embedding_matrix_shard in zip( input.shards, embedding_matrix.shards ) ] - return ReplicatedTensor(ts=shards) + return ReplicatedTensor(ts=shards, devices=input.devices) @equal.override(ReplicatedTensor) @@ -446,7 +465,9 @@ def set_element(l: List, idx: int, el: Any) -> List: ) for shard in tensor.shards ] - return SplitPrimitiveTensor(ts=shards, shard_dim=tensor.shard_dim) + return SplitPrimitiveTensor( + ts=shards, shard_dim=tensor.shard_dim, devices=tensor.devices + ) @flatten.override(ReplicatedTensor) @@ -454,7 +475,7 @@ def flatten_replicated( input: ReplicatedTensor, start_dim: int, end_dim: int ) -> ReplicatedTensor: shards = [shard.flatten(start_dim, end_dim) for shard in input.shards] - return ReplicatedTensor(ts=shards) + return ReplicatedTensor(ts=shards, devices=input.devices) @flatten.override(SplitPrimitiveTensor) @@ -477,7 +498,7 @@ def flatten_split( if input.shard_dim <= start_dim else input.shard_dim - (end_dim_resolved - start_dim) ) - return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim) + return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim, devices=input.devices) @gather.override(ReplicatedTensor, ReplicatedTensor) @@ -485,11 +506,12 @@ def gather_replicated( input: ReplicatedTensor, dim: int, index: ReplicatedTensor ) -> Tensor: assert input.shard_count == index.shard_count + assert input.devices == index.devices shards = [ gather(input_shard, dim, index_shard) for input_shard, index_shard in zip(input.shards, index.shards) ] - return ReplicatedTensor(ts=shards) + return ReplicatedTensor(ts=shards, devices=input.devices) @group_norm_affine.override( 
@@ -509,7 +531,7 @@ def shareded_group_norm_affine(input, weight, bias, *, num_groups, eps): for x, w, b in zip(input.shards, weight.shards, bias.shards) ] - return SplitPrimitiveTensor(shard_dim=1, ts=result_shards) + return SplitPrimitiveTensor(shard_dim=1, ts=result_shards, devices=input.devices) @index_copy_.override(SplitPrimitiveTensor, ReplicatedTensor, SplitPrimitiveTensor) @@ -546,7 +568,7 @@ def index_put__split( ) -> SplitPrimitiveTensor: # TODO: verify that the values split dimension is not being indexed or implement # this case. - indices = [replicate(idx, count=inout.shard_count) for idx in indices] + indices = [replicate(idx, devices=inout.devices) for idx in indices] for i, shard in enumerate(inout.shards): shard_indices = [idx.shards[i] for idx in indices] shard.index_put_(shard_indices, values.shards[i]) @@ -560,11 +582,12 @@ def index_select_replicated( index: ReplicatedTensor, ) -> ReplicatedTensor: assert tensor.shard_count == index.shard_count + assert tensor.devices == index.devices shards = [ index_select(tensor_shard, dim, index_shard) for tensor_shard, index_shard in zip(tensor.shards, index.shards) ] - return ReplicatedTensor(ts=shards) + return ReplicatedTensor(ts=shards, devices=tensor.devices) @index_select.override(SplitPrimitiveTensor, ReplicatedTensor) @@ -581,7 +604,9 @@ def index_select_split_replicated( index_select(tensor_shard, dim, index_shard) for tensor_shard, index_shard in zip(tensor.shards, index.shards) ] - return SplitPrimitiveTensor(ts=shards, shard_dim=tensor.shard_dim) + return SplitPrimitiveTensor( + ts=shards, shard_dim=tensor.shard_dim, devices=tensor.devices + ) @interpolate.override(ReplicatedTensor) @@ -606,7 +631,7 @@ def interpolate_replicated( ) for shard in input.shards ] - return ReplicatedTensor(ts=shards) + return ReplicatedTensor(ts=shards, devices=input.devices) @interpolate.override(SplitPrimitiveTensor) @@ -632,7 +657,9 @@ def interpolate_split_batch_or_channel( ) for shard in input.shards ] - return SplitPrimitiveTensor(ts=shards, shard_dim=input.shard_dim) + return SplitPrimitiveTensor( + ts=shards, shard_dim=input.shard_dim, devices=input.devices + ) @layer_norm.override(SplitPrimitiveTensor, Tensor, Tensor) @@ -641,7 +668,9 @@ def layer_norm_default(input, weight, bias, *, eps): weight.shape ) shards = [layer_norm(shard, weight, bias, eps=eps) for shard in input.shards] - return SplitPrimitiveTensor(shard_dim=input.shard_dim, ts=shards) + return SplitPrimitiveTensor( + shard_dim=input.shard_dim, ts=shards, devices=input.devices + ) # Linear @@ -695,7 +724,9 @@ def matmul_replicated_lhs_split_rhs( matmul(lhs_shard, rhs_shard) for (lhs_shard, rhs_shard) in zip(lhs.shards, rhs.shards) ] - return SplitPrimitiveTensor(ts=shards, shard_dim=len(lhs.shape) - 2 + rhs.shard_dim) + return SplitPrimitiveTensor( + ts=shards, shard_dim=len(lhs.shape) - 2 + rhs.shard_dim, devices=rhs.devices + ) @matmul.override(SplitPrimitiveTensor, Tensor) @@ -707,7 +738,7 @@ def matmul_split_lhs( shards = [ matmul(lhs_shard, rhs, transpose_rhs=transpose_rhs) for lhs_shard in lhs.shards ] - return SplitPrimitiveTensor(shard_dim=lhs.shard_dim, ts=shards) + return SplitPrimitiveTensor(shard_dim=lhs.shard_dim, ts=shards, devices=lhs.devices) @matmul.override(Tensor, SplitPrimitiveTensor) @@ -732,7 +763,9 @@ def matmul_split_rhs( for partial_rhs in rhs.shards ] # The partial is split columnwise (last dim). 
- return SplitPrimitiveTensor(shard_dim=len(lhs.shape) - 1, ts=partials) + return SplitPrimitiveTensor( + shard_dim=len(lhs.shape) - 1, ts=partials, devices=rhs.devices + ) @matmul.override(SplitPrimitiveTensor, ReplicatedTensor) @@ -747,13 +780,14 @@ def matmul_split_lhs_replicated_rhs( matmul(lhs_shard, rhs_shard) for (lhs_shard, rhs_shard) in zip(lhs.shards, rhs.shards) ] - return SplitPrimitiveTensor(ts=shards, shard_dim=lhs.shard_dim) + return SplitPrimitiveTensor(ts=shards, shard_dim=lhs.shard_dim, devices=lhs.devices) @matmul.override(SplitPrimitiveTensor, SplitPrimitiveTensor) def matmul_split( lhs: SplitPrimitiveTensor, rhs: SplitPrimitiveTensor, *, transpose_rhs: bool ) -> UnreducedTensor | SplitPrimitiveTensor: + assert lhs.devices == rhs.devices if lhs.shard_count != rhs.shard_count: raise ValueError( f"Cannot matmul split tensors of different shard_count: " @@ -771,7 +805,7 @@ def matmul_split( matmul(partial_lhs, partial_rhs) for partial_lhs, partial_rhs in zip(lhs.shards, rhs.shards) ] - return UnreducedTensor(ts=partials) + return UnreducedTensor(ts=partials, devices=lhs.devices) is_batched_matmul = len(lhs.shape) > 2 or len(rhs.shape) > 2 if ( @@ -784,7 +818,9 @@ def matmul_split( matmul(lhs_shard, rhs_shard) for lhs_shard, rhs_shard in zip(lhs.shards, rhs.shards) ] - return SplitPrimitiveTensor(ts=shards, shard_dim=lhs.shard_dim) + return SplitPrimitiveTensor( + ts=shards, shard_dim=lhs.shard_dim, devices=lhs.devices + ) # -1 for missing parallel dim. lhs_parallel_dim = len(lhs.shape) - 2 @@ -812,6 +848,8 @@ def matmul_split( Optional[ReplicatedTensor], ) def scaled_dot_product_attention_sharded(q, k, v, a, is_causal, scale) -> Tensor: + assert q.devices == k.devices + assert q.devices == v.devices if q.shard_count != k.shard_count or q.shard_count != v.shard_count: raise ValueError("Incompatible number of shards for qkv") @@ -837,7 +875,9 @@ def scaled_dot_product_attention_sharded(q, k, v, a, is_causal, scale) -> Tensor ) output_shards.append(o_s) - return SplitPrimitiveTensor(ts=output_shards, shard_dim=q.shard_dim) + return SplitPrimitiveTensor( + ts=output_shards, shard_dim=q.shard_dim, devices=q.devices + ) @mean.override(ReplicatedTensor) @@ -849,7 +889,7 @@ def mean_replicated( dtype: torch.dtype, ) -> None: shards = [mean(shard, dim=dim, keepdim=keepdim, dtype=dtype) for shard in x.shards] - return ReplicatedTensor(ts=shards) + return ReplicatedTensor(ts=shards, devices=x.devices) @module_register_buffer.override(torch.nn.Module, ShardedTensor) @@ -865,48 +905,56 @@ def module_register_buffer_sharded( def permute_split(tensor: SplitPrimitiveTensor, dims: List[int]): permuted_shards = [permute(shard, dims) for shard in tensor.shards] permuted_shard_dim = dims[tensor.shard_dim] - return SplitPrimitiveTensor(ts=permuted_shards, shard_dim=permuted_shard_dim) + return SplitPrimitiveTensor( + ts=permuted_shards, shard_dim=permuted_shard_dim, devices=tensor.devices + ) @permute.override(ReplicatedTensor) def permute_replicated(tensor: ReplicatedTensor, dims: List[int]): permuted_shards = [permute(shard, dims) for shard in tensor.shards] - return ReplicatedTensor(ts=permuted_shards) + return ReplicatedTensor(ts=permuted_shards, devices=tensor.devices) @repeat.override(ReplicatedTensor) def repeat_replicated(input: ReplicatedTensor, *sizes: List[int]) -> ReplicatedTensor: shards = [repeat(shard, *sizes) for shard in input.shards] - return ReplicatedTensor(ts=shards) + return ReplicatedTensor(ts=shards, devices=input.devices) @replicate.override(ReplicatedTensor) -def 
replicate_replicated(input: ReplicatedTensor, *, count: int) -> ReplicatedTensor: - if input.shard_count != count: - raise ValueError(f"Number of shards not equal ({input.shard_count} != {count})") +def replicate_replicated(input: ReplicatedTensor, *, devices: list) -> ReplicatedTensor: + if input.shard_count != len(devices): + raise ValueError( + f"Number of shards not equal ({input.shard_count} != {len(devices)})" + ) return input @replicate.override(SplitPrimitiveTensor) -def replicate_split(input: SplitPrimitiveTensor, *, count: int) -> ReplicatedTensor: - if input.shard_count != count: - raise ValueError(f"Number of shards not equal ({input.shard_count} != {count})") +def replicate_split(input: SplitPrimitiveTensor, *, devices: list) -> ReplicatedTensor: + if input.shard_count != len(devices): + raise ValueError( + f"Number of shards not equal ({input.shard_count} != {len(devices)})" + ) return all_gather(input) @replicate.override(UnreducedTensor) -def replicate_unreduced(input: UnreducedTensor, *, count: int) -> ReplicatedTensor: - if input.shard_count != count: - raise ValueError(f"Number of shards not equal ({input.shard_count} != {count})") +def replicate_unreduced(input: UnreducedTensor, *, devices: list) -> ReplicatedTensor: + if input.shard_count != len(devices): + raise ValueError( + f"Number of shards not equal ({input.shard_count} != {len(devices)})" + ) return all_reduce(input) @replicate.override(Tensor) -def replicate_unsharded(input, *, count: int) -> ReplicatedTensor: +def replicate_unsharded(input, *, devices: list) -> ReplicatedTensor: torch_input = unbox_tensor(input) # If we have a torch input replicating we can assume we need to transfer: - torch_inputs = [transfer_to_logical_device(torch_input, i) for i in range(count)] - return ReplicatedTensor(ts=torch_inputs) + torch_inputs = [transfer_to_logical_device(torch_input, d.ordinal) for d in devices] + return ReplicatedTensor(ts=torch_inputs, devices=devices) @reshape.override(SplitPrimitiveTensor) @@ -975,7 +1023,8 @@ def reshard_all_to_unsharded(input: AnyTensor, spec: sharding.Unsharded) -> Tens def reshard_all_to_replicated( input: AnyTensor, spec: sharding.Replicated ) -> ReplicatedTensor: - return replicate(input, spec.shard_count) + device = [DeviceAffinity(i) for i in range(spec.shard_count)] + return replicate(input, device) @reshard_split.override(Tensor) @@ -1018,7 +1067,7 @@ def slice_range_along_dim(dim: int, start: int, end: int): ] for shard_idx, shard in enumerate(input.shards) ] - return SplitPrimitiveTensor(ts=shards, shard_dim=dim) + return SplitPrimitiveTensor(ts=shards, shard_dim=dim, devices=input.devices) @reshard_like.override(Tensor, SplitPrimitiveTensor) @@ -1044,7 +1093,7 @@ def reshard_like_unsharded_to_replicated( tensor, like: ReplicatedTensor ) -> ReplicatedTensor: torch_tensor = unbox_tensor(tensor) - return replicate(torch_tensor, count=like.shard_count) + return replicate(torch_tensor, devices=like.devices) @reshard_like.override(ReplicatedTensor, ReplicatedTensor) @@ -1079,14 +1128,14 @@ def reshard_like_split_to_split( def reshard_like_unreduced_to_replicated( tensor: UnreducedTensor, like: ReplicatedTensor ) -> ReplicatedTensor: - return replicate(tensor, count=like.shard_count) + return replicate(tensor, devices=like.devices) @sharded_cat.override(SplitPrimitiveTensor) def sharded_cat_unsharded(tensor: SplitPrimitiveTensor): shard_ts = [ - transfer_to_logical_device(shard.as_torch(), 0) if i != 0 else shard.as_torch() - for i, shard in enumerate(tensor.shards) + 
transfer_to_logical_device(shard.as_torch(), tensor.devices[0].ordinal) + for shard in tensor.shards ] return torch.cat(shard_ts, dim=tensor.shard_dim) @@ -1122,20 +1171,25 @@ def softmax_split( ), "Softmax along split dimension is not supported." shards = [softmax(shard, dim=dim, dtype=dtype) for shard in tensor.shards] return SplitPrimitiveTensor( - ts=shards, shard_dim=tensor.shard_dim, shape=tensor.shape + ts=shards, + shard_dim=tensor.shard_dim, + shape=tensor.shape, + devices=tensor.devices, ) @to.override(ReplicatedTensor) def to_replicated(tensor: ReplicatedTensor, *args, **kwargs): shards = [to(shard, *args, **kwargs) for shard in tensor.shards] - return ReplicatedTensor(ts=shards) + return ReplicatedTensor(ts=shards, devices=tensor.devices) @to.override(SplitPrimitiveTensor) def to_split(tensor: SplitPrimitiveTensor, *args, **kwargs): shards = [to(shard, *args, **kwargs) for shard in tensor.shards] - return SplitPrimitiveTensor(ts=shards, shard_dim=tensor.shard_dim) + return SplitPrimitiveTensor( + ts=shards, shard_dim=tensor.shard_dim, devices=tensor.devices + ) @transpose.override(SplitPrimitiveTensor) @@ -1148,7 +1202,7 @@ def transpose_split( shard_dim = dim1 elif shard_dim == dim1: shard_dim = dim0 - return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim) + return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim, devices=tensor.devices) @unflatten.override(SplitPrimitiveTensor) @@ -1160,7 +1214,7 @@ def unflatten_split( shard_dim = input.shard_dim if dim < shard_dim: shard_dim += len(sizes) - 1 - return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim) + return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim, devices=input.devices) @unshard.override(ReplicatedTensor) @@ -1177,8 +1231,7 @@ def unshard_split(input: SplitPrimitiveTensor) -> Tensor: def unshard_unreduced(input: UnreducedTensor) -> Tensor: shards = input.shards shards = [ - shard if i == 0 else transfer_to_logical_device(shard, 0) - for i, shard in enumerate(shards) + transfer_to_logical_device(shard, input.devices[0].ordinal) for shard in shards ] return functools.reduce(lambda x, y: elementwise(torch.add, x, y), shards) @@ -1268,13 +1321,13 @@ def unsqueeze_split(tensor: SplitPrimitiveTensor, dim: int) -> SplitPrimitiveTen dim_resolved = dim if dim >= 0 else dim + len(tensor.shape) + 1 if shard_dim >= dim_resolved: shard_dim += 1 - return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim) + return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim, devices=tensor.devices) @unsqueeze.override(ReplicatedTensor) def unsqueeze_replicated(tensor: ReplicatedTensor, dim: int) -> SplitPrimitiveTensor: shards = [torch.unsqueeze(unbox_tensor(shard), dim) for shard in tensor.shards] - return ReplicatedTensor(ts=shards) + return ReplicatedTensor(ts=shards, devices=tensor.devices) @view.override(SplitPrimitiveTensor) @@ -1303,7 +1356,7 @@ def view_split(tensor: SplitPrimitiveTensor, shape: List[int]) -> SplitPrimitive new_shard_shape = list(shape) new_shard_shape[shard_dim] //= tensor.shard_count shards = [view(shard, new_shard_shape) for shard in tensor.shards] - res = SplitPrimitiveTensor(shard_dim=shard_dim, ts=shards) + res = SplitPrimitiveTensor(shard_dim=shard_dim, ts=shards, devices=tensor.devices) assert math.prod(res.shape) == math.prod(tensor.shape) return res @@ -1311,22 +1364,26 @@ def view_split(tensor: SplitPrimitiveTensor, shape: List[int]) -> SplitPrimitive @view_as_complex.override(SplitPrimitiveTensor) def view_as_complex_split(tensor: SplitPrimitiveTensor) -> SplitPrimitiveTensor: shards 
= [view_as_complex(shard) for shard in tensor.shards] - return SplitPrimitiveTensor(ts=shards, shard_dim=tensor.shard_dim) + return SplitPrimitiveTensor( + ts=shards, shard_dim=tensor.shard_dim, devices=tensor.devices + ) @view_as_complex.override(ReplicatedTensor) def view_as_complex_rep(tensor: ReplicatedTensor) -> ReplicatedTensor: shards = [view_as_complex(shard) for shard in tensor.shards] - return ReplicatedTensor(ts=shards) + return ReplicatedTensor(ts=shards, devices=tensor.devices) @view_as_real.override(SplitPrimitiveTensor) def view_as_real_split(tensor: SplitPrimitiveTensor) -> SplitPrimitiveTensor: shards = [view_as_real(shard) for shard in tensor.shards] - return SplitPrimitiveTensor(ts=shards, shard_dim=tensor.shard_dim) + return SplitPrimitiveTensor( + ts=shards, shard_dim=tensor.shard_dim, devices=tensor.devices + ) @view_as_real.override(ReplicatedTensor) def view_as_real_rep(tensor: ReplicatedTensor) -> ReplicatedTensor: shards = [view_as_real(shard) for shard in tensor.shards] - return ReplicatedTensor(ts=shards) + return ReplicatedTensor(ts=shards, devices=tensor.devices) diff --git a/sharktank/sharktank/ops/signatures.py b/sharktank/sharktank/ops/signatures.py index cbe959d28..b0504fdfa 100644 --- a/sharktank/sharktank/ops/signatures.py +++ b/sharktank/sharktank/ops/signatures.py @@ -796,7 +796,7 @@ def _repeat_trampoline( @overridable -def replicate(input: AnyTensor, count: int) -> ShardedTensor: +def replicate(input: AnyTensor, devices: list) -> ShardedTensor: """Replicate across devices. Possibly reshards if required.""" @@ -805,11 +805,11 @@ def replicate(input: AnyTensor, count: int) -> ShardedTensor: @replicate.trampoline def _replicate_trampoline( - d: SignatureDispatcher, input: AnyTensor, count: int + d: SignatureDispatcher, input: AnyTensor, devices: list ) -> ShardedTensor: tensors = (input,) for override in d.find_overrides(tensors): - result = override(input, count=count) + result = override(input, devices=devices) if result is not NotImplemented: return override, result else: diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py index 153a5d753..b2638d451 100644 --- a/sharktank/sharktank/types/tensors.py +++ b/sharktank/sharktank/types/tensors.py @@ -28,8 +28,10 @@ from torch import Tensor from torch.utils._pytree import register_pytree_node, SequenceKey import torch.utils._pytree + from ..utils.math import ceildiv from iree.turbine.aot import ( + DeviceAffinity, DeviceTensorTrait, ExternalTensorTrait, ) @@ -791,6 +793,7 @@ def __init__( ts: list[torch.Tensor], name: str = UnnamedTensorName, shape: Optional[list[int]], + devices: Optional[list], ): assert len(ts) > 0 assert shard_dim is None or (shard_dim >= 0 and len(ts[0].shape) > shard_dim) @@ -805,6 +808,22 @@ def __init__( for i, t in enumerate(ts) ) + if devices is None: + devices = [DeviceAffinity(i) for i in range(len(self._shards))] + + self._devices: tuple[DeviceAffinity] = tuple(devices) + + for i, t in enumerate(ts): + DeviceTensorTrait(i).set(t) + + def assign_affinities(self, affinities): + assert len(affinities) == len(self._devices) + self._devices = tuple(affinities) + for s, d in zip(self._shards, self._devices): + if isinstance(s, DefaultPrimitiveTensor): + s = s.as_torch() + DeviceTensorTrait(d.ordinal, d.queues).set(s) + @property def shard_count(self) -> int: return len(self._shards) @@ -813,6 +832,10 @@ def shard_count(self) -> int: def shards(self) -> tuple[InferenceTensor]: return self._shards + @property + def devices(self) -> 
tuple[DeviceAffinity]: + return self._devices + @property def is_replicated(self) -> bool: return False @@ -871,8 +894,6 @@ def create( try: t = raw_tensors[t_name] ts.append(t) - # TODO: this should be changed to tracked device affinity - DeviceTensorTrait(i).set(t) except KeyError as e: raise IOError( f"Missing component tensor '{t_name}' in {raw_tensors.keys()}" @@ -965,6 +986,7 @@ def __init__( shard_count: None | int = None, name: str = UnnamedTensorName, shape: Optional[list[int]] = None, + devices: list | None = None, ): """ If `ts` is a list of tensors, it is interpreted as the shards. @@ -973,16 +995,25 @@ def __init__( will be split along dimension `shard_dim` into `shard_count` number of pieces. """ - if isinstance(ts, torch.Tensor): - from ..ops import transfer_to_logical_device + assert shard_count is None or not isinstance(ts, torch.Tensor) + shard_count = shard_count if shard_count is not None else len(ts) + + if devices is None: + devices = [DeviceAffinity(i) for i in range(shard_count)] + + assert len(devices) == shard_count + + if isinstance(ts, torch.Tensor): assert shard_count is not None ts = ts.split(ceildiv(ts.shape[shard_dim], shard_count), dim=shard_dim) - ts = [transfer_to_logical_device(t, i) for i, t in enumerate(ts)] + + from ..ops import transfer_to_logical_device + + ts = [transfer_to_logical_device(t, d.ordinal) for t, d in zip(ts, devices)] assert len(ts) == shard_count shard_count = None - assert shard_count is None assert len(ts) > 0 first_shape = ts[0].shape assert len(first_shape) > shard_dim @@ -1004,7 +1035,9 @@ def __init__( s == t for i, (s, t) in enumerate(zip(shape, t_shape)) if i != shard_dim ), f"Shape mismatch for non-split dimension for tensor shard {i} with shape {t.shape}" - super().__init__(name=name, ts=ts, shape=shape, shard_dim=shard_dim) + super().__init__( + name=name, ts=ts, shape=shape, shard_dim=shard_dim, devices=devices + ) def _is_slicing_split_dim(self, key): if isinstance( @@ -1072,7 +1105,9 @@ def __getitem__(self, key): # Rank reduction dimension before the split dim. shard_dim -= 1 - return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim) + return SplitPrimitiveTensor( + ts=shards, shard_dim=shard_dim, devices=self.devices + ) def __setitem__(self, key, value): assert isinstance(value, SplitPrimitiveTensor) @@ -1098,7 +1133,8 @@ class ReplicatedTensor(ShardedTensor): def __init__( self, *, - ts: list[torch.Tensor] | torch.Tensor, + ts: list[torch.Tensor], + devices: None | list[DeviceAffinity] = None, shard_count: None | int = None, name: str = UnnamedTensorName, ): @@ -1108,13 +1144,7 @@ def __init__( If `ts` is a tensor then `shard_count` must be provided and it, will be replicated that many times. 
""" - if isinstance(ts, torch.Tensor): - assert shard_count is not None - from ..ops import transfer_to_logical_device - - ts = [transfer_to_logical_device(ts, i) for i in range(shard_count)] - shard_count = None - + assert not isinstance(ts, torch.Tensor) assert shard_count is None assert len(ts) > 0 first_shape = ts[0].shape @@ -1134,6 +1164,22 @@ def __init__( for i, t in enumerate(ts) ) + if devices is None: + devices = tuple([DeviceAffinity(i) for i in range(len(ts))]) + + self._devices: tuple[DeviceAffinity] = tuple(devices) + + for d, t in zip(devices, ts): + DeviceTensorTrait(d.ordinal, d.queues).set(t) + + def assign_affinities(self, affinities): + assert len(affinities) == len(self._devices) + self._devices = tuple(affinities) + for s, d in zip(self._shards, self._devices): + if isinstance(s, DefaultPrimitiveTensor): + s = s.as_torch() + DeviceTensorTrait(d.ordinal, d.queues).set(s) + @property def shard_count(self) -> int: return len(self._shards) @@ -1142,6 +1188,10 @@ def shard_count(self) -> int: def shards(self) -> tuple[InferenceTensor]: return self._shards + @property + def devices(self) -> tuple[DeviceAffinity]: + return self._devices + @property def is_replicated(self) -> bool: return True @@ -1188,10 +1238,6 @@ def create( nt = deepcopy(t) ts.append(nt) - # TODO This should be changed to assigned affinities - for i in range(shard_count): - DeviceTensorTrait(i).set(ts[i]) - except KeyError as e: raise IOError(f"Missing component tensor '' in {raw_tensors.keys()}") from e return cls(name=name, ts=ts) @@ -1210,12 +1256,13 @@ def __getitem__(self, key): else: shard_keys.append(k) shards.append(shard[*shard_keys]) - return ReplicatedTensor(ts=shards) + return ReplicatedTensor(ts=shards, devices=self.devices) def __repr__(self): return ( f"ReplicatedTensor({self.name}, {self.shape}, " f"shard_count={len(self._shards)} " + f"devices={self.devices} " f"of {self.shards[0].shape})" ) @@ -1243,13 +1290,14 @@ def __init__( self, *, ts: list[torch.Tensor], + devices: Optional[list] = None, name: str = UnnamedTensorName, shape: Optional[list[int]] = None, ): assert len(ts) > 0 shape = list(ts[0].shape if shape is None else shape) assert all(shape == list(t.shape) for t in ts) - super().__init__(name=name, ts=ts, shape=shape, shard_dim=None) + super().__init__(name=name, ts=ts, shape=shape, shard_dim=None, devices=devices) def flatten_tensor_tree( diff --git a/sharktank/sharktank/types/theta.py b/sharktank/sharktank/types/theta.py index 021925169..8c6d235dd 100644 --- a/sharktank/sharktank/types/theta.py +++ b/sharktank/sharktank/types/theta.py @@ -25,8 +25,6 @@ from .tensors import ( InferenceTensor, - PrimitiveTensor, - QuantizedTensor, InferenceTensorMetadata, DefaultPrimitiveTensor, REGISTERED_INFERENCE_TENSOR_CLASSES, diff --git a/sharktank/sharktank/utils/cli.py b/sharktank/sharktank/utils/cli.py index 99917c2d3..c90e03cea 100644 --- a/sharktank/sharktank/utils/cli.py +++ b/sharktank/sharktank/utils/cli.py @@ -69,9 +69,14 @@ def add_model_options(parser: argparse.ArgumentParser): default="decomposed", choices=["decomposed", "torch"], ) + parser.add_argument( + "--skip-prefill", + help="Skips exporting prefill", + action="store_true", + ) parser.add_argument( "--skip-decode", - help="Enables prefill only, skips decode", + help="Skips exporting decode", action="store_true", ) From 9a061e641d2f21033cc2d1b6b47e6a61480e03b9 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Wed, 11 Dec 2024 18:07:36 -0800 Subject: [PATCH 2/2] fix tests --- .../examples/sharding/export_ffn_net.py | 6 
++- sharktank/sharktank/models/llama/llama.py | 7 +++- sharktank/sharktank/ops/sharded_impls.py | 42 +++++++++++++++---- sharktank/sharktank/ops/signatures.py | 8 ++-- sharktank/sharktank/types/tensors.py | 12 +++--- sharktank/tests/layers/kv_cache_test.py | 11 +++-- .../layers/sharded_paged_kv_cache_test.py | 16 ++++--- .../layers/sharded_rotary_embedding_test.py | 6 ++- .../tests/models/llama/sharded_llama_test.py | 2 +- sharktank/tests/ops/sharded_test.py | 6 +-- sharktank/tests/types/tensors_test.py | 4 +- 11 files changed, 83 insertions(+), 37 deletions(-) diff --git a/sharktank/sharktank/examples/sharding/export_ffn_net.py b/sharktank/sharktank/examples/sharding/export_ffn_net.py index f261a92e1..7d1b1a2be 100644 --- a/sharktank/sharktank/examples/sharding/export_ffn_net.py +++ b/sharktank/sharktank/examples/sharding/export_ffn_net.py @@ -17,6 +17,8 @@ import torch import torch.nn as nn +from iree.turbine.aot import DeviceAffinity + from ...layers import * from ... import ops from ...types import * @@ -50,7 +52,9 @@ def forward(self, x: torch.Tensor): ffn_gate_weight = self.theta.tensor("ffn_gate", "weight") ffn_up_weight = self.theta.tensor("ffn_up", "weight") ffn_down_weight = self.theta.tensor("ffn_down", "weight") - x = ops.replicate(x, count=ffn_gate_weight.shard_count) + + devices = [DeviceAffinity(i) for i in range(ffn_down_weight.shard_count)] + x = ops.replicate(x, devices=devices) ffn_gate = ops.elementwise( torch.nn.functional.silu, ops.linear(x, ffn_gate_weight) ) diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py index fa481328d..367082f48 100644 --- a/sharktank/sharktank/models/llama/llama.py +++ b/sharktank/sharktank/models/llama/llama.py @@ -9,6 +9,8 @@ from dataclasses import dataclass from typing import Union +from iree.turbine.aot import DeviceAffinity + import torch import torch.nn as nn import torch.nn.functional as F @@ -62,7 +64,7 @@ class PagedLlamaModelV1(BaseCausalLMModel): unsharded result or chain it with other tensor-parallel operations. 
""" - def __init__(self, theta: Theta, config: LlamaModelConfig, devices: list): + def __init__(self, theta: Theta, config: LlamaModelConfig, devices: list = None): hp = config.hp super().__init__( theta, @@ -80,6 +82,9 @@ def __init__(self, theta: Theta, config: LlamaModelConfig, devices: list): self.use_hf = config.use_hf self.attention_kernel = config.attention_kernel + if devices is None: + devices = [DeviceAffinity(i) for i in range(config.tensor_parallelism_size)] + self.add_module( "token_embedding", TokenEmbeddingLayer(theta("token_embd"), dtype=config.activation_dtype), diff --git a/sharktank/sharktank/ops/sharded_impls.py b/sharktank/sharktank/ops/sharded_impls.py index 45505648a..00ebd0138 100644 --- a/sharktank/sharktank/ops/sharded_impls.py +++ b/sharktank/sharktank/ops/sharded_impls.py @@ -923,38 +923,62 @@ def repeat_replicated(input: ReplicatedTensor, *sizes: List[int]) -> ReplicatedT @replicate.override(ReplicatedTensor) -def replicate_replicated(input: ReplicatedTensor, *, devices: list) -> ReplicatedTensor: - if input.shard_count != len(devices): +def replicate_replicated( + input: ReplicatedTensor, *, devices: list, count: int +) -> ReplicatedTensor: + if devices is not None and input.shard_count != len(devices): raise ValueError( f"Number of shards not equal ({input.shard_count} != {len(devices)})" ) + if count is not None and input.shard_count != count: + raise ValueError(f"Number of shards not equal ({input.shard_count} != {count})") return input @replicate.override(SplitPrimitiveTensor) -def replicate_split(input: SplitPrimitiveTensor, *, devices: list) -> ReplicatedTensor: - if input.shard_count != len(devices): +def replicate_split( + input: SplitPrimitiveTensor, *, devices: list, count: int +) -> ReplicatedTensor: + if devices is not None and input.shard_count != len(devices): raise ValueError( f"Number of shards not equal ({input.shard_count} != {len(devices)})" ) + if count is not None and input.shard_count != count: + raise ValueError(f"Number of shards not equal ({input.shard_count} != {count})") return all_gather(input) @replicate.override(UnreducedTensor) -def replicate_unreduced(input: UnreducedTensor, *, devices: list) -> ReplicatedTensor: - if input.shard_count != len(devices): +def replicate_unreduced( + input: UnreducedTensor, *, devices: list, count: int +) -> ReplicatedTensor: + if devices is not None and input.shard_count != len(devices): raise ValueError( f"Number of shards not equal ({input.shard_count} != {len(devices)})" ) + if count is not None and input.shard_count != count: + raise ValueError(f"Number of shards not equal ({input.shard_count} != {count})") return all_reduce(input) @replicate.override(Tensor) -def replicate_unsharded(input, *, devices: list) -> ReplicatedTensor: +def replicate_unsharded(input, *, devices: list, count: int) -> ReplicatedTensor: torch_input = unbox_tensor(input) # If we have a torch input replicating we can assume we need to transfer: - torch_inputs = [transfer_to_logical_device(torch_input, d.ordinal) for d in devices] - return ReplicatedTensor(ts=torch_inputs, devices=devices) + if devices is not None: + torch_inputs = [ + transfer_to_logical_device(torch_input, d.ordinal) for d in devices + ] + return ReplicatedTensor(ts=torch_inputs, devices=devices) + + if count is not None: + devices = [DeviceAffinity(i) for i in range(count)] + torch_inputs = [ + transfer_to_logical_device(torch_input, i) for i in range(count) + ] + return ReplicatedTensor(ts=torch_inputs, devices=devices) + + raise ValueError(f"Devices or 
count is required") @reshape.override(SplitPrimitiveTensor) diff --git a/sharktank/sharktank/ops/signatures.py b/sharktank/sharktank/ops/signatures.py index b0504fdfa..c5dadb698 100644 --- a/sharktank/sharktank/ops/signatures.py +++ b/sharktank/sharktank/ops/signatures.py @@ -796,7 +796,9 @@ def _repeat_trampoline( @overridable -def replicate(input: AnyTensor, devices: list) -> ShardedTensor: +def replicate( + input: AnyTensor, devices: list = None, count: int = None +) -> ShardedTensor: """Replicate across devices. Possibly reshards if required.""" @@ -805,11 +807,11 @@ def replicate(input: AnyTensor, devices: list) -> ShardedTensor: @replicate.trampoline def _replicate_trampoline( - d: SignatureDispatcher, input: AnyTensor, devices: list + d: SignatureDispatcher, input: AnyTensor, devices: list = None, count: int = None ) -> ShardedTensor: tensors = (input,) for override in d.find_overrides(tensors): - result = override(input, devices=devices) + result = override(input, devices=devices, count=count) if result is not NotImplemented: return override, result else: diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py index b2638d451..c4fad275f 100644 --- a/sharktank/sharktank/types/tensors.py +++ b/sharktank/sharktank/types/tensors.py @@ -813,9 +813,6 @@ def __init__( self._devices: tuple[DeviceAffinity] = tuple(devices) - for i, t in enumerate(ts): - DeviceTensorTrait(i).set(t) - def assign_affinities(self, affinities): assert len(affinities) == len(self._devices) self._devices = tuple(affinities) @@ -893,6 +890,7 @@ def create( t_name = str(i) try: t = raw_tensors[t_name] + DeviceTensorTrait(i).set(t) ts.append(t) except KeyError as e: raise IOError( @@ -996,7 +994,7 @@ def __init__( number of pieces. """ - assert shard_count is None or not isinstance(ts, torch.Tensor) + assert shard_count is None or isinstance(ts, torch.Tensor) shard_count = shard_count if shard_count is not None else len(ts) if devices is None: @@ -1169,9 +1167,6 @@ def __init__( self._devices: tuple[DeviceAffinity] = tuple(devices) - for d, t in zip(devices, ts): - DeviceTensorTrait(d.ordinal, d.queues).set(t) - def assign_affinities(self, affinities): assert len(affinities) == len(self._devices) self._devices = tuple(affinities) @@ -1238,6 +1233,9 @@ def create( nt = deepcopy(t) ts.append(nt) + for i, t in enumerate(ts): + DeviceTensorTrait(i).set(t) + except KeyError as e: raise IOError(f"Missing component tensor '' in {raw_tensors.keys()}") from e return cls(name=name, ts=ts) diff --git a/sharktank/tests/layers/kv_cache_test.py b/sharktank/tests/layers/kv_cache_test.py index 65b42c986..228535ee8 100644 --- a/sharktank/tests/layers/kv_cache_test.py +++ b/sharktank/tests/layers/kv_cache_test.py @@ -8,6 +8,7 @@ import torch +from iree.turbine.aot import DeviceAffinity from sharktank.ops import replicate, reshard_split, unshard from sharktank.layers import * from sharktank.types import * @@ -148,6 +149,8 @@ def test_sharded_direct(): write_seq_length = seq_length - 5 + devices = [DeviceAffinity(i) for i in range(shard_count)] + # Write a prefill in: write_ones = reshard_split( torch.full( @@ -204,7 +207,7 @@ def test_sharded_direct(): ) write_pos = replicate( - torch.full((bs,), write_seq_length, dtype=torch.int64), shard_count + torch.full((bs,), write_seq_length, dtype=torch.int64), devices ) cache.write_timestep( allocation, @@ -379,11 +382,13 @@ def test_sharded_paged(): device=None, ) + devices = [DeviceAffinity(i) for i in range(shard_count)] + write_seq_length = seq_length - 4 
page_count = bs * seq_length // block_seq_stride page_ids = torch.arange(page_count, dtype=torch.int64) page_ids = page_ids.view(bs, seq_length // block_seq_stride) - page_ids = replicate(page_ids, shard_count) + page_ids = replicate(page_ids, devices=devices) write_page_ids = page_ids[:, : write_seq_length // block_seq_stride] allocation = cache.allocate(page_count=page_count) @@ -458,7 +463,7 @@ def test_sharded_paged(): ) write_pos = replicate( - torch.full((bs,), write_seq_length, dtype=torch.int64), shard_count + torch.full((bs,), write_seq_length, dtype=torch.int64), devices ) cache.write_timestep( diff --git a/sharktank/tests/layers/sharded_paged_kv_cache_test.py b/sharktank/tests/layers/sharded_paged_kv_cache_test.py index d7b6a0b33..c54b20459 100644 --- a/sharktank/tests/layers/sharded_paged_kv_cache_test.py +++ b/sharktank/tests/layers/sharded_paged_kv_cache_test.py @@ -5,13 +5,16 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import unittest -from sharktank.layers import PagedKVCache import torch -from sharktank.utils import iterables_equal from copy import deepcopy from typing import List, Tuple + +from iree.turbine.aot import DeviceAffinity + from sharktank import ops +from sharktank.layers import PagedKVCache from sharktank.types import SplitPrimitiveTensor +from sharktank.utils import iterables_equal class ShardedPagedKVCacheTest(unittest.TestCase): @@ -31,6 +34,7 @@ def setUp(self): self.batch_size = 11 self.block_seq_len = 2 self.max_seq_len = self.block_seq_len * self.block_seq_stride + self.devices = [DeviceAffinity(i) for i in range(self.shard_count)] self.cache = PagedKVCache( transformer_block_count=self.transformer_block_count, @@ -131,7 +135,7 @@ def testRead(self): for t in read_into_partitions_snapshot ] ) - sharded_page_ids = ops.replicate(page_ids, count=self.shard_count) + sharded_page_ids = ops.replicate(page_ids, devices=self.devices) self.sharded_cache.read( state=sharded_cache_state, read_into_partitions=sharded_read_into_partitions, @@ -179,8 +183,8 @@ def testWriteTimestep(self): for t in cache_partitions ] ) - sharded_seq_positions = ops.replicate(seq_positions, count=self.shard_count) - sharded_page_ids = ops.replicate(page_ids, count=self.shard_count) + sharded_seq_positions = ops.replicate(seq_positions, devices=self.devices) + sharded_page_ids = ops.replicate(page_ids, devices=self.devices) self.sharded_cache.write_timestep( state=sharded_cache_state, cache_partitions=sharded_cache_partitions, @@ -224,7 +228,7 @@ def testWrite(self): for t in cache_partitions ] ) - sharded_page_ids = ops.replicate(page_ids, count=self.shard_count) + sharded_page_ids = ops.replicate(page_ids, devices=self.devices) self.sharded_cache.write( state=sharded_cache_state, cache_partitions=sharded_cache_partitions, diff --git a/sharktank/tests/layers/sharded_rotary_embedding_test.py b/sharktank/tests/layers/sharded_rotary_embedding_test.py index f24b8313a..2d2d95691 100644 --- a/sharktank/tests/layers/sharded_rotary_embedding_test.py +++ b/sharktank/tests/layers/sharded_rotary_embedding_test.py @@ -7,6 +7,8 @@ import torch +from iree.turbine.aot import DeviceAffinity + from sharktank.layers import RotaryEmbeddingLayer from sharktank import ops from sharktank.types import ( @@ -27,6 +29,8 @@ def test_sharded_rotary_table(): max_seqlen = 128 rope_freq_base = None + devices = [DeviceAffinity(i) for i in range(4)] + # First we setup and get the default rotary embedding layer xq = torch.rand((bs, max_seqlen, heads, rope_dims), dtype=torch.float) xk = torch.rand((bs, 
max_seqlen, heads, rope_dims), dtype=torch.float) @@ -45,7 +49,7 @@ def test_sharded_rotary_table(): rope_dimension_count=rope_dims, max_seqlen=max_seqlen, rope_freq_base=rope_freq_base, - tensor_parallelism_size=4, + devices=devices, ) sq = shard_layer(xt=xq, start_index=0) sk = shard_layer(xt=xk, start_index=0) diff --git a/sharktank/tests/models/llama/sharded_llama_test.py b/sharktank/tests/models/llama/sharded_llama_test.py index 386061731..b1d0aa745 100644 --- a/sharktank/tests/models/llama/sharded_llama_test.py +++ b/sharktank/tests/models/llama/sharded_llama_test.py @@ -6,7 +6,7 @@ import unittest import pytest -from typing import Any, List, Tuple, OrderedDict +from typing import Any, Tuple, OrderedDict from sharktank.models.llama.llama import LlamaModelConfig, PagedLlamaModelV1 import sharktank.ops as ops from sharktank.types import unbox_tensor, Dataset, UnreducedTensor, SplitPrimitiveTensor diff --git a/sharktank/tests/ops/sharded_test.py b/sharktank/tests/ops/sharded_test.py index e5efaa948..adfb38359 100644 --- a/sharktank/tests/ops/sharded_test.py +++ b/sharktank/tests/ops/sharded_test.py @@ -630,7 +630,7 @@ def testAttentionShardedBatchMask(self): q_s = SplitPrimitiveTensor(shard_dim=0, ts=q.split(1, dim=0)) k_s = SplitPrimitiveTensor(shard_dim=0, ts=k.split(1, dim=0)) v_s = SplitPrimitiveTensor(shard_dim=0, ts=v.split(1, dim=0)) - a_s = ReplicatedTensor(ts=a, shard_count=4) + a_s = ReplicatedTensor(ts=[a] * 4) expected_result = ops.scaled_dot_product_attention( q, k, v, a=a, is_causal=False @@ -754,7 +754,7 @@ def testShardedLhsReplcatedRhs(self): expected_result = torch.matmul(a, b) shard_count = 3 a_sharded = SplitPrimitiveTensor(ts=a, shard_dim=1, shard_count=shard_count) - b_sharded = ReplicatedTensor(ts=b, shard_count=shard_count) + b_sharded = ReplicatedTensor(ts=[b] * shard_count) res_sharded = ops.matmul(a_sharded, b_sharded) assert isinstance(res_sharded, SplitPrimitiveTensor) assert res_sharded.shard_dim == 1 @@ -837,7 +837,7 @@ def testReplicateUnsharded(self): tensor = torch.rand(4, 5, dtype=torch.float32) shard_count = 3 actual_result = ops.replicate(tensor, count=shard_count) - expected_result = ReplicatedTensor(ts=tensor, shard_count=shard_count) + expected_result = ReplicatedTensor(ts=[tensor] * shard_count) assert expected_result.is_deep_equal(actual_result) # Test that is a copy. diff --git a/sharktank/tests/types/tensors_test.py b/sharktank/tests/types/tensors_test.py index 4af4a513f..37898a7c5 100644 --- a/sharktank/tests/types/tensors_test.py +++ b/sharktank/tests/types/tensors_test.py @@ -106,7 +106,7 @@ def testUnreducedTensorSaveLoad(self): def testReplicatedTensorExtractSlice(self): tensor = torch.rand([2, 3, 4], dtype=torch.float32) - replicated_tensor = ReplicatedTensor(ts=tensor, shard_count=3) + replicated_tensor = ReplicatedTensor(ts=[tensor] * 3) s = [slice(1, 2), slice(0, 3, 2), None] expected_result = tensor[s] replicated_sliced_tensor = replicated_tensor[s] @@ -116,7 +116,7 @@ def testReplicatedTensorExtractSlice(self): def testReplicatedTensorExtractElement(self): tensor = torch.rand([2, 3, 4], dtype=torch.float32) - replicated_tensor = ReplicatedTensor(ts=tensor, shard_count=3) + replicated_tensor = ReplicatedTensor(ts=[tensor] * 3) idx = ( 1, 2,