diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py index 900c1a9ae..945c418c3 100644 --- a/sharktank/sharktank/examples/export_paged_llm_v1.py +++ b/sharktank/sharktank/examples/export_paged_llm_v1.py @@ -18,11 +18,12 @@ # TODO: Should be using a base class with the protocol supported. from ..models.llama.llama import LlamaModelConfig, PagedLlamaModelV1 -from ..models.llama.sharding import shard_theta from ..models.mixtral.mixtral import * from ..models.grok.grok import * from .. import ops +from typing import Union, Sequence + def main(): from ..utils import cli @@ -55,6 +56,11 @@ def main(): help="Enables strictness during export", action="store_true", ) + parser.add_argument( + "--use-queue-affinities", + help="Enables queue affinities for multidevice", + action="store_true", + ) cli.add_quantization_options(parser) cli.add_model_options(parser) @@ -74,18 +80,44 @@ def main(): tensor_parallelism_size=tensor_parallelism_size, use_hf=False, static_tables=False, # Rely on the compiler for hoisting tables. - kv_cache_type="direct" if args.bs == [1] else "paged", + kv_cache_type="paged" if args.bs == [1] else "paged", attention_kernel=args.attention_kernel, ) llama_config.fake_quant = args.fake_quant + def setup_queue_affinities(sharding): + return [DeviceAffinity(0, [str(i)]) for i in range(sharding)] + + def assign_affinities(theta, affinities): + def transform(tensors: Union[None, InferenceTensor, Sequence[InferenceTensor]]): + if tensors is None: + return tensors + if isinstance(tensors, Sequence): + return [transform(t) for t in tensors] + if isinstance(tensors, ShardedTensor): + tensors.assign_affinities(affinities) + return tensors + if isinstance(tensors, InferenceTensor): + return tensors + + raise ValueError("Unknown device for reassigned affinities") + + return theta.transform(transform) + + affinities = [ + DeviceAffinity(i) for i in range(llama_config.tensor_parallelism_size) + ] + if args.use_queue_affinities: + affinities = setup_queue_affinities(llama_config.tensor_parallelism_size) + dataset.root_theta = assign_affinities(dataset.root_theta, affinities) + if llama_config.hp.expert_count: if llama_config.hp.model_arch == "grok": model = PagedGrokModelV1(dataset.root_theta, llama_config) else: model = PagedMixtralModelV1(dataset.root_theta, llama_config) else: - model = PagedLlamaModelV1(dataset.root_theta, llama_config) + model = PagedLlamaModelV1(dataset.root_theta, llama_config, affinities) def generate_params_json( hp: LlamaHParams, prefill_bs: list[int], decode_bs: list[int] @@ -121,7 +153,7 @@ def generate_params_json( fxb = FxProgramsBuilder(model) - def setup_cache(model, shard_count): + def setup_cache(model, affinities): if model.config.kv_cache_type == "paged": cache_state = model.cache.allocate( page_count=hp.context_length // llama_config.block_seq_stride @@ -143,24 +175,21 @@ def setup_cache(model, shard_count): ] for i in range(llama_config.tensor_parallelism_size): - arg_affinities[i] = DeviceAffinity(str(i)) + arg_affinities[i] = affinities[i] return unpacked, shard_dim, dynamic_shapes, arg_affinities elif model.config.kv_cache_type == "direct": - cache_state = model.cache.allocate(bs=1) - # Direct cache dimensions: - # 2 * transformer_block_count of... 
- # [bs, seq_length, attn_head_count, attn_head_dim] - dynamic_shapes = [None] - arg_affinities = {} - shard_dim = None - return torch.stack(cache_state), shard_dim, dynamic_shapes, arg_affinities + raise NotImplementedError(f"Direct cache is not currently functional") + else: raise NotImplementedError(f"Unsupported KV cache type: {type(model.cache)}") def repack_cache(cache, shard_dim): - return [SplitPrimitiveTensor(ts=c, shard_dim=shard_dim) for c in cache] + return [ + SplitPrimitiveTensor(ts=c, shard_dim=shard_dim, devices=affinities) + for c in cache + ] def generate_batch_prefill(bs: int): # torch.export.Dim would make min at least 2 @@ -177,7 +206,7 @@ def generate_batch_prefill(bs: int): seq_lens = torch.empty(bs, dtype=torch.int64) cache, cache_shard_dim, cache_dynamic_shapes, arg_affinities = setup_cache( - model, llama_config.tensor_parallelism_size + model, affinities ) if llama_config.tensor_parallelism_size > 1: @@ -219,9 +248,9 @@ def _(model, tokens, seq_lens, seq_block_ids, cs): if llama_config.tensor_parallelism_size != 1: shard_count = llama_config.tensor_parallelism_size - tokens = ops.replicate(tokens, count=shard_count) - attention_mask = ops.replicate(attention_mask, count=shard_count) - seq_block_ids = ops.replicate(seq_block_ids, count=shard_count) + tokens = ops.replicate(tokens, devices=affinities) + attention_mask = ops.replicate(attention_mask, devices=affinities) + seq_block_ids = ops.replicate(seq_block_ids, devices=affinities) cache_tensors = repack_cache(cs, cache_shard_dim) @@ -256,7 +285,7 @@ def generate_batch_decode(bs: int): cache_shard_dim, cache_dynamic_shapes, arg_affinities, - ) = setup_cache(model, llama_config.tensor_parallelism_size) + ) = setup_cache(model, affinities) if llama_config.tensor_parallelism_size > 1: # We need to offset the indices for the cache @@ -264,7 +293,7 @@ def generate_batch_decode(bs: int): # Inputs have default affinity 0 for i in range(4): - arg_affinities[i] = DeviceAffinity("0") + arg_affinities[i] = affinities[0] dynamic_shapes = { "tokens": {}, @@ -303,12 +332,10 @@ def _( attention_mask = model.decode_attention_mask(input_mask) if llama_config.tensor_parallelism_size != 1: - shard_count = llama_config.tensor_parallelism_size - - tokens = ops.replicate(tokens, count=shard_count) - attention_mask = ops.replicate(attention_mask, count=shard_count) - start_positions = ops.replicate(start_positions, count=shard_count) - seq_block_ids = ops.replicate(seq_block_ids, count=shard_count) + tokens = ops.replicate(tokens, devices=affinities) + attention_mask = ops.replicate(attention_mask, devices=affinities) + start_positions = ops.replicate(start_positions, devices=affinities) + seq_block_ids = ops.replicate(seq_block_ids, devices=affinities) cache_state = repack_cache(cache_state, cache_shard_dim) @@ -327,7 +354,8 @@ def _( bsizes = [] for bs in args.bs: - generate_batch_prefill(bs) + if not args.skip_prefill: + generate_batch_prefill(bs) if not args.skip_decode: generate_batch_decode(bs) bsizes.append(bs) diff --git a/sharktank/sharktank/examples/paged_llm_v1.py b/sharktank/sharktank/examples/paged_llm_v1.py index b30acc026..cc1c96d8c 100644 --- a/sharktank/sharktank/examples/paged_llm_v1.py +++ b/sharktank/sharktank/examples/paged_llm_v1.py @@ -15,6 +15,8 @@ import torch +from iree.turbine.aot import DeviceAffinity + from ..layers import * from ..types import * @@ -156,9 +158,10 @@ def prefill(self): token_ids = self.token_ids if model.config.tensor_parallelism_size != 1: tp = model.config.tensor_parallelism_size - 
token_ids = replicate(token_ids, tp) - attention_mask = replicate(attention_mask, tp) - seq_block_ids_tensor = replicate(seq_block_ids_tensor, tp) + devices = [DeviceAffinity(i) for i in range(tp)] + token_ids = replicate(token_ids, devices) + attention_mask = replicate(attention_mask, devices) + seq_block_ids_tensor = replicate(seq_block_ids_tensor, devices) logits = model.prefill( token_ids, @@ -199,10 +202,11 @@ def decode(self): if model.config.tensor_parallelism_size != 1: tp = model.config.tensor_parallelism_size - self.next_tokens = replicate(self.next_tokens, tp) - start_positions = replicate(start_positions, tp) - seq_block_ids_tensor = replicate(seq_block_ids_tensor, tp) - decode_attention_mask = replicate(decode_attention_mask, tp) + devices = [DeviceAffinity(i) for i in range(tp)] + self.next_tokens = replicate(self.next_tokens, devices) + start_positions = replicate(start_positions, devices) + seq_block_ids_tensor = replicate(seq_block_ids_tensor, devices) + decode_attention_mask = replicate(decode_attention_mask, devices) logits = model.decode( self.next_tokens, @@ -279,8 +283,7 @@ def main(): tensor_parallelism_size=args.tensor_parallelism_size, fake_quant=args.fake_quant, ) - if config.tensor_parallelism_size > 1: - dataset.root_theta = shard_theta(dataset.root_theta, config) + affinities = [DeviceAffinity(i) for i in range(args.tensor_parallelism_size)] if config.hp.expert_count: if config.hp.model_arch == "grok": @@ -288,7 +291,7 @@ def main(): else: model = PagedMixtralModelV1(dataset.root_theta, config) else: - model = PagedLlamaModelV1(dataset.root_theta, config) + model = PagedLlamaModelV1(dataset.root_theta, config, devices=affinities) if args.save_intermediates_path: from ..utils.patching import SaveModuleResultTensorsPatch diff --git a/sharktank/sharktank/examples/sharding/export_ffn_net.py b/sharktank/sharktank/examples/sharding/export_ffn_net.py index f261a92e1..7d1b1a2be 100644 --- a/sharktank/sharktank/examples/sharding/export_ffn_net.py +++ b/sharktank/sharktank/examples/sharding/export_ffn_net.py @@ -17,6 +17,8 @@ import torch import torch.nn as nn +from iree.turbine.aot import DeviceAffinity + from ...layers import * from ... 
import ops from ...types import * @@ -50,7 +52,9 @@ def forward(self, x: torch.Tensor): ffn_gate_weight = self.theta.tensor("ffn_gate", "weight") ffn_up_weight = self.theta.tensor("ffn_up", "weight") ffn_down_weight = self.theta.tensor("ffn_down", "weight") - x = ops.replicate(x, count=ffn_gate_weight.shard_count) + + devices = [DeviceAffinity(i) for i in range(ffn_down_weight.shard_count)] + x = ops.replicate(x, devices=devices) ffn_gate = ops.elementwise( torch.nn.functional.silu, ops.linear(x, ffn_gate_weight) ) diff --git a/sharktank/sharktank/layers/kv_cache.py b/sharktank/sharktank/layers/kv_cache.py index 46e94ff90..82973b576 100644 --- a/sharktank/sharktank/layers/kv_cache.py +++ b/sharktank/sharktank/layers/kv_cache.py @@ -292,7 +292,9 @@ def unflatten_page_table( shards = [ shard.unflatten(1, self.sub_page_dims) for shard in page_slab.shards ] - return SplitPrimitiveTensor(ts=shards, shard_dim=4) + return SplitPrimitiveTensor( + ts=shards, shard_dim=4, devices=page_slab.devices + ) def shard_state( self, state: List[torch.Tensor] diff --git a/sharktank/sharktank/layers/rotary_embedding.py b/sharktank/sharktank/layers/rotary_embedding.py index 99ecf5057..3c1411fce 100644 --- a/sharktank/sharktank/layers/rotary_embedding.py +++ b/sharktank/sharktank/layers/rotary_embedding.py @@ -27,10 +27,11 @@ def __init__( use_hf: bool = False, static_tables: bool = False, use_table: bool = True, - tensor_parallelism_size: int = 1, + devices: list | None = None, ): super().__init__() self.device = device + self.devices = devices self.rope_dimension_count = rope_dimension_count self.max_seqlen = max_seqlen self.use_hf = use_hf @@ -38,7 +39,7 @@ def __init__( self.use_table = use_table self.rope_freq_base = rope_freq_base if rope_freq_base is not None else 10000.0 - self.tensor_parallelism_size = tensor_parallelism_size + self.devices = devices if static_tables: ops.module_register_buffer( self, "static_rotary_embed_table", self._create_rotary_embed_table() @@ -80,7 +81,9 @@ def forward( ) for xt_shard, rotary_shard in zip(xt.shards, rotary_shards) ] - xt = SplitPrimitiveTensor(ts=xt_shards, shard_dim=xt.shard_dim) + xt = SplitPrimitiveTensor( + ts=xt_shards, shard_dim=xt.shard_dim, devices=xt.devices + ) return xt else: return self.forward_unsharded( @@ -189,7 +192,7 @@ def compute_batch_mask( self._compute_rotary_embed_table(s.flatten()).unflatten(0, shape) for s in positions_seq.shards ] - freqs_cis = ReplicatedTensor(ts=ts) + freqs_cis = ReplicatedTensor(ts=ts, devices=positions_seq.devices) else: freqs_cis = self._compute_rotary_embed_table(positions_seq.flatten()) freqs_cis = freqs_cis.unflatten(0, shape) @@ -215,7 +218,9 @@ def apply_batched_mask( ) for xt_shard, mask_shard in zip(xt.shards, mask.shards) ] - xt = SplitPrimitiveTensor(ts=xt_shards, shard_dim=xt.shard_dim) + xt = SplitPrimitiveTensor( + ts=xt_shards, shard_dim=xt.shard_dim, devices=xt.devices + ) return xt def apply_batched_mask_unsharded(self, *, xt: torch.Tensor, mask: torch.Tensor): @@ -259,8 +264,8 @@ def _create_rotary_embed_table(self): return self._replicate(freqs_cis) def _replicate(self, t): - if self.tensor_parallelism_size > 1: + if self.devices is not None and len(self.devices) > 1: # Replicate across all devices, the data is not a lot and the computation is cheap. 
- t = ops.replicate(t, self.tensor_parallelism_size) + t = ops.replicate(t, self.devices) return t diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py index 0a9a6f1c3..367082f48 100644 --- a/sharktank/sharktank/models/llama/llama.py +++ b/sharktank/sharktank/models/llama/llama.py @@ -9,6 +9,8 @@ from dataclasses import dataclass from typing import Union +from iree.turbine.aot import DeviceAffinity + import torch import torch.nn as nn import torch.nn.functional as F @@ -62,7 +64,7 @@ class PagedLlamaModelV1(BaseCausalLMModel): unsharded result or chain it with other tensor-parallel operations. """ - def __init__(self, theta: Theta, config: LlamaModelConfig): + def __init__(self, theta: Theta, config: LlamaModelConfig, devices: list = None): hp = config.hp super().__init__( theta, @@ -80,6 +82,9 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): self.use_hf = config.use_hf self.attention_kernel = config.attention_kernel + if devices is None: + devices = [DeviceAffinity(i) for i in range(config.tensor_parallelism_size)] + self.add_module( "token_embedding", TokenEmbeddingLayer(theta("token_embd"), dtype=config.activation_dtype), @@ -91,9 +96,9 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): rope_freq_base=hp.rope_freq_base, max_seqlen=hp.context_length, device=self.device, + devices=devices, use_hf=self.use_hf, static_tables=config.static_tables, - tensor_parallelism_size=config.tensor_parallelism_size, ), ) self.add_module( @@ -238,8 +243,12 @@ def decode( ) for _ in range(self.config.tensor_parallelism_size) ] - xk_temp = SplitPrimitiveTensor(ts=xk_temp_shard, shard_dim=2) - xv_temp = SplitPrimitiveTensor(ts=xv_temp_shard, shard_dim=2) + xk_temp = SplitPrimitiveTensor( + ts=xk_temp_shard, shard_dim=2, devices=tokens.devices + ) + xv_temp = SplitPrimitiveTensor( + ts=xv_temp_shard, shard_dim=2, devices=tokens.devices + ) h = self.token_embedding(tokens) self.trace_tensor("llama.token_embedding", h) diff --git a/sharktank/sharktank/ops/sharded_impls.py b/sharktank/sharktank/ops/sharded_impls.py index 015e88a4b..00ebd0138 100644 --- a/sharktank/sharktank/ops/sharded_impls.py +++ b/sharktank/sharktank/ops/sharded_impls.py @@ -12,6 +12,8 @@ import math import functools +from iree.turbine.aot import DeviceAffinity + from ..types import ( AnyTensor, DefaultPrimitiveTensor, @@ -42,14 +44,14 @@ def all_gather_split( shards = [ cat( [ - shard if i == j else transfer_to_logical_device(shard, i) - for j, shard in enumerate(input.shards) + transfer_to_logical_device(shard, device.ordinal) + for shard in input.shards ], dim=dim, ) - for i in range(input.shard_count) + for device in input.devices ] - return ReplicatedTensor(ts=shards) + return ReplicatedTensor(ts=shards, devices=input.devices) @all_reduce.override(AllOfType(SplitPrimitiveTensor, UnreducedTensor)) @@ -63,23 +65,25 @@ def all_reduce_split_or_unreduced( functools.reduce( lambda x, y: elementwise(torch.add, x, y), [ - shard if i == j else transfer_to_logical_device(shard, i) - for j, shard in enumerate(input.shards) + transfer_to_logical_device(shard, device.ordinal) + for shard in input.shards ], ) - for i in range(input.shard_count) + for device in input.devices ] - return ReplicatedTensor(ts=shards) + return ReplicatedTensor(ts=shards, devices=input.devices) @cat.override(AllOfType(ReplicatedTensor)) def cat_replicated(tensors: Sequence[ReplicatedTensor], dim: int) -> ReplicatedTensor: assert len(tensors) > 0 + devices = tensors[0].devices shard_count = 
tensors[0].shard_count assert all([t.shard_count == shard_count for t in tensors]) + assert all([t.devices == devices for t in tensors]) shards = [cat(shards, dim) for shards in zip(*[t.shards for t in tensors])] - return ReplicatedTensor(ts=shards) + return ReplicatedTensor(ts=shards, devices=devices) @cat.override(AllOfType(SplitPrimitiveTensor)) @@ -97,9 +101,12 @@ def cat_split( shard_dim = tensors[0].shard_dim shard_count = tensors[0].shard_count + devices = tensors[0].devices + for t in tensors: + assert t.devices == tensors[0].devices if dim != shard_dim: shards = [cat(shards, dim) for shards in zip(*[t.shards for t in tensors])] - return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim) + return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim, devices=devices) else: # TODO: implement efficient cat along split dim. # This would probably result in doing the concatenation on one device. @@ -191,6 +198,7 @@ def conv2d_replicated_input_split_weight_and_bias( bias is None or weight.shard_dim == 0 and bias.shard_dim == 0 ), "Only sharding of output channel dimension is supported" assert groups == 1 + assert input.devices == weight.devices shards = [ conv2d( @@ -208,7 +216,7 @@ def conv2d_replicated_input_split_weight_and_bias( [None] * weight.shard_count if bias is None else bias.shards, ) ] - return SplitPrimitiveTensor(shard_dim=1, ts=shards) + return SplitPrimitiveTensor(shard_dim=1, ts=shards, devices=input.devices) conv2d.override( @@ -252,7 +260,7 @@ def conv2d_split_weight_and_bias( [None] * weight.shard_count if bias is None else bias.shards, ) ] - return SplitPrimitiveTensor(shard_dim=1, ts=shards) + return SplitPrimitiveTensor(shard_dim=1, ts=shards, devices=weight.devices) else: assert False, "Unsupported, TODO: handle split channels in input" @@ -270,13 +278,15 @@ def conv2d_split_weight_and_bias( @elementwise.override(ReplicatedTensor) def replicated_elementwise_unary(operator, x: ReplicatedTensor, *args, **kwargs): partials = [operator(unbox_tensor(pt), *args, **kwargs) for pt in x.shards] - return ReplicatedTensor(ts=partials) + return ReplicatedTensor(ts=partials, devices=x.devices) @elementwise.override(SplitPrimitiveTensor) def split_elementwise_unary(operator, x: SplitPrimitiveTensor, *args, **kwargs): partials = [operator(unbox_tensor(pt), *args, **kwargs) for pt in x.shards] - return SplitPrimitiveTensor(shard_dim=x.shard_dim, shape=x.shape, ts=partials) + return SplitPrimitiveTensor( + shard_dim=x.shard_dim, shape=x.shape, ts=partials, devices=x.devices + ) @elementwise.override(ReplicatedTensor, ReplicatedTensor) @@ -284,11 +294,12 @@ def replicated_elementwise_binary( operator, x: ReplicatedTensor, y: ReplicatedTensor, *args, **kwargs ): assert x.shard_count == y.shard_count + assert x.devices == y.devices shards = [ operator(unbox_tensor(shard_x), unbox_tensor(shard_y), *args, **kwargs) for shard_x, shard_y in zip(x.shards, y.shards) ] - return ReplicatedTensor(ts=shards) + return ReplicatedTensor(ts=shards, devices=x.devices) @elementwise.override(SplitPrimitiveTensor, SplitPrimitiveTensor) @@ -298,6 +309,7 @@ def split_elementwise_binary( assert x.shard_count == y.shard_count x_shard_dim, y_shard_dim = broadcast_dims([x.shard_dim, y.shard_dim], [x, y]) assert x_shard_dim == y_shard_dim + assert x.devices == y.devices pt_xs = [unbox_tensor(pt) for pt in x.shards] pt_ys = [unbox_tensor(pt) for pt in y.shards] partials = [ @@ -307,6 +319,7 @@ def split_elementwise_binary( shard_dim=x.shard_dim, shape=torch.broadcast_shapes(x.shape, y.shape), ts=partials, + 
devices=x.devices, ) @@ -316,7 +329,9 @@ def elementwise_binary_split_lhs_scalar_rhs( ): pt_xs = [unbox_tensor(pt) for pt in x.shards] partials = [operator(pt_x, y, *args, **kwargs) for pt_x in pt_xs] - return SplitPrimitiveTensor(shard_dim=x.shard_dim, shape=x.shape, ts=partials) + return SplitPrimitiveTensor( + shard_dim=x.shard_dim, shape=x.shape, ts=partials, devices=x.devices + ) @elementwise.override(SplitPrimitiveTensor, Tensor) @@ -345,6 +360,7 @@ def elementwise_binary_split_lhs_replicated_rhs( operator, x: SplitPrimitiveTensor, y: ReplicatedTensor, *args, **kwargs ): assert len(y.shape) > 0, "0-rank not supported" + assert x.devices == y.devices if x.shard_count != y.shard_count: raise ValueError( f"Operands' number of shards not equal ({x.shard_count} != {y.shard_count})" @@ -360,7 +376,9 @@ def elementwise_binary_split_lhs_replicated_rhs( elementwise(operator, x_shard, y_shard) for x_shard, y_shard in zip(x.shards, y.shards) ] - return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim_in_res) + return SplitPrimitiveTensor( + ts=shards, shard_dim=shard_dim_in_res, devices=x.devices + ) y_sharded = reshard_like(y, like=x) return elementwise(operator, x, y_sharded, *args, **kwargs) @@ -402,13 +420,14 @@ def embedding_lookup_default( dtype: Optional[torch.dtype], ): assert input.shard_count == embedding_matrix.shard_count + assert input.devices == embedding_matrix.devices shards = [ embedding_lookup(input_shard, embedding_matrix_shard, dtype) for input_shard, embedding_matrix_shard in zip( input.shards, embedding_matrix.shards ) ] - return ReplicatedTensor(ts=shards) + return ReplicatedTensor(ts=shards, devices=input.devices) @equal.override(ReplicatedTensor) @@ -446,7 +465,9 @@ def set_element(l: List, idx: int, el: Any) -> List: ) for shard in tensor.shards ] - return SplitPrimitiveTensor(ts=shards, shard_dim=tensor.shard_dim) + return SplitPrimitiveTensor( + ts=shards, shard_dim=tensor.shard_dim, devices=tensor.devices + ) @flatten.override(ReplicatedTensor) @@ -454,7 +475,7 @@ def flatten_replicated( input: ReplicatedTensor, start_dim: int, end_dim: int ) -> ReplicatedTensor: shards = [shard.flatten(start_dim, end_dim) for shard in input.shards] - return ReplicatedTensor(ts=shards) + return ReplicatedTensor(ts=shards, devices=input.devices) @flatten.override(SplitPrimitiveTensor) @@ -477,7 +498,7 @@ def flatten_split( if input.shard_dim <= start_dim else input.shard_dim - (end_dim_resolved - start_dim) ) - return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim) + return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim, devices=input.devices) @gather.override(ReplicatedTensor, ReplicatedTensor) @@ -485,11 +506,12 @@ def gather_replicated( input: ReplicatedTensor, dim: int, index: ReplicatedTensor ) -> Tensor: assert input.shard_count == index.shard_count + assert input.devices == index.devices shards = [ gather(input_shard, dim, index_shard) for input_shard, index_shard in zip(input.shards, index.shards) ] - return ReplicatedTensor(ts=shards) + return ReplicatedTensor(ts=shards, devices=input.devices) @group_norm_affine.override( @@ -509,7 +531,7 @@ def shareded_group_norm_affine(input, weight, bias, *, num_groups, eps): for x, w, b in zip(input.shards, weight.shards, bias.shards) ] - return SplitPrimitiveTensor(shard_dim=1, ts=result_shards) + return SplitPrimitiveTensor(shard_dim=1, ts=result_shards, devices=input.devices) @index_copy_.override(SplitPrimitiveTensor, ReplicatedTensor, SplitPrimitiveTensor) @@ -546,7 +568,7 @@ def index_put__split( ) -> 
SplitPrimitiveTensor: # TODO: verify that the values split dimension is not being indexed or implement # this case. - indices = [replicate(idx, count=inout.shard_count) for idx in indices] + indices = [replicate(idx, devices=inout.devices) for idx in indices] for i, shard in enumerate(inout.shards): shard_indices = [idx.shards[i] for idx in indices] shard.index_put_(shard_indices, values.shards[i]) @@ -560,11 +582,12 @@ def index_select_replicated( index: ReplicatedTensor, ) -> ReplicatedTensor: assert tensor.shard_count == index.shard_count + assert tensor.devices == index.devices shards = [ index_select(tensor_shard, dim, index_shard) for tensor_shard, index_shard in zip(tensor.shards, index.shards) ] - return ReplicatedTensor(ts=shards) + return ReplicatedTensor(ts=shards, devices=tensor.devices) @index_select.override(SplitPrimitiveTensor, ReplicatedTensor) @@ -581,7 +604,9 @@ def index_select_split_replicated( index_select(tensor_shard, dim, index_shard) for tensor_shard, index_shard in zip(tensor.shards, index.shards) ] - return SplitPrimitiveTensor(ts=shards, shard_dim=tensor.shard_dim) + return SplitPrimitiveTensor( + ts=shards, shard_dim=tensor.shard_dim, devices=tensor.devices + ) @interpolate.override(ReplicatedTensor) @@ -606,7 +631,7 @@ def interpolate_replicated( ) for shard in input.shards ] - return ReplicatedTensor(ts=shards) + return ReplicatedTensor(ts=shards, devices=input.devices) @interpolate.override(SplitPrimitiveTensor) @@ -632,7 +657,9 @@ def interpolate_split_batch_or_channel( ) for shard in input.shards ] - return SplitPrimitiveTensor(ts=shards, shard_dim=input.shard_dim) + return SplitPrimitiveTensor( + ts=shards, shard_dim=input.shard_dim, devices=input.devices + ) @layer_norm.override(SplitPrimitiveTensor, Tensor, Tensor) @@ -641,7 +668,9 @@ def layer_norm_default(input, weight, bias, *, eps): weight.shape ) shards = [layer_norm(shard, weight, bias, eps=eps) for shard in input.shards] - return SplitPrimitiveTensor(shard_dim=input.shard_dim, ts=shards) + return SplitPrimitiveTensor( + shard_dim=input.shard_dim, ts=shards, devices=input.devices + ) # Linear @@ -695,7 +724,9 @@ def matmul_replicated_lhs_split_rhs( matmul(lhs_shard, rhs_shard) for (lhs_shard, rhs_shard) in zip(lhs.shards, rhs.shards) ] - return SplitPrimitiveTensor(ts=shards, shard_dim=len(lhs.shape) - 2 + rhs.shard_dim) + return SplitPrimitiveTensor( + ts=shards, shard_dim=len(lhs.shape) - 2 + rhs.shard_dim, devices=rhs.devices + ) @matmul.override(SplitPrimitiveTensor, Tensor) @@ -707,7 +738,7 @@ def matmul_split_lhs( shards = [ matmul(lhs_shard, rhs, transpose_rhs=transpose_rhs) for lhs_shard in lhs.shards ] - return SplitPrimitiveTensor(shard_dim=lhs.shard_dim, ts=shards) + return SplitPrimitiveTensor(shard_dim=lhs.shard_dim, ts=shards, devices=lhs.devices) @matmul.override(Tensor, SplitPrimitiveTensor) @@ -732,7 +763,9 @@ def matmul_split_rhs( for partial_rhs in rhs.shards ] # The partial is split columnwise (last dim). 
- return SplitPrimitiveTensor(shard_dim=len(lhs.shape) - 1, ts=partials) + return SplitPrimitiveTensor( + shard_dim=len(lhs.shape) - 1, ts=partials, devices=rhs.devices + ) @matmul.override(SplitPrimitiveTensor, ReplicatedTensor) @@ -747,13 +780,14 @@ def matmul_split_lhs_replicated_rhs( matmul(lhs_shard, rhs_shard) for (lhs_shard, rhs_shard) in zip(lhs.shards, rhs.shards) ] - return SplitPrimitiveTensor(ts=shards, shard_dim=lhs.shard_dim) + return SplitPrimitiveTensor(ts=shards, shard_dim=lhs.shard_dim, devices=lhs.devices) @matmul.override(SplitPrimitiveTensor, SplitPrimitiveTensor) def matmul_split( lhs: SplitPrimitiveTensor, rhs: SplitPrimitiveTensor, *, transpose_rhs: bool ) -> UnreducedTensor | SplitPrimitiveTensor: + assert lhs.devices == rhs.devices if lhs.shard_count != rhs.shard_count: raise ValueError( f"Cannot matmul split tensors of different shard_count: " @@ -771,7 +805,7 @@ def matmul_split( matmul(partial_lhs, partial_rhs) for partial_lhs, partial_rhs in zip(lhs.shards, rhs.shards) ] - return UnreducedTensor(ts=partials) + return UnreducedTensor(ts=partials, devices=lhs.devices) is_batched_matmul = len(lhs.shape) > 2 or len(rhs.shape) > 2 if ( @@ -784,7 +818,9 @@ def matmul_split( matmul(lhs_shard, rhs_shard) for lhs_shard, rhs_shard in zip(lhs.shards, rhs.shards) ] - return SplitPrimitiveTensor(ts=shards, shard_dim=lhs.shard_dim) + return SplitPrimitiveTensor( + ts=shards, shard_dim=lhs.shard_dim, devices=lhs.devices + ) # -1 for missing parallel dim. lhs_parallel_dim = len(lhs.shape) - 2 @@ -812,6 +848,8 @@ def matmul_split( Optional[ReplicatedTensor], ) def scaled_dot_product_attention_sharded(q, k, v, a, is_causal, scale) -> Tensor: + assert q.devices == k.devices + assert q.devices == v.devices if q.shard_count != k.shard_count or q.shard_count != v.shard_count: raise ValueError("Incompatible number of shards for qkv") @@ -837,7 +875,9 @@ def scaled_dot_product_attention_sharded(q, k, v, a, is_causal, scale) -> Tensor ) output_shards.append(o_s) - return SplitPrimitiveTensor(ts=output_shards, shard_dim=q.shard_dim) + return SplitPrimitiveTensor( + ts=output_shards, shard_dim=q.shard_dim, devices=q.devices + ) @mean.override(ReplicatedTensor) @@ -849,7 +889,7 @@ def mean_replicated( dtype: torch.dtype, ) -> None: shards = [mean(shard, dim=dim, keepdim=keepdim, dtype=dtype) for shard in x.shards] - return ReplicatedTensor(ts=shards) + return ReplicatedTensor(ts=shards, devices=x.devices) @module_register_buffer.override(torch.nn.Module, ShardedTensor) @@ -865,48 +905,80 @@ def module_register_buffer_sharded( def permute_split(tensor: SplitPrimitiveTensor, dims: List[int]): permuted_shards = [permute(shard, dims) for shard in tensor.shards] permuted_shard_dim = dims[tensor.shard_dim] - return SplitPrimitiveTensor(ts=permuted_shards, shard_dim=permuted_shard_dim) + return SplitPrimitiveTensor( + ts=permuted_shards, shard_dim=permuted_shard_dim, devices=tensor.devices + ) @permute.override(ReplicatedTensor) def permute_replicated(tensor: ReplicatedTensor, dims: List[int]): permuted_shards = [permute(shard, dims) for shard in tensor.shards] - return ReplicatedTensor(ts=permuted_shards) + return ReplicatedTensor(ts=permuted_shards, devices=tensor.devices) @repeat.override(ReplicatedTensor) def repeat_replicated(input: ReplicatedTensor, *sizes: List[int]) -> ReplicatedTensor: shards = [repeat(shard, *sizes) for shard in input.shards] - return ReplicatedTensor(ts=shards) + return ReplicatedTensor(ts=shards, devices=input.devices) @replicate.override(ReplicatedTensor) -def 
replicate_replicated(input: ReplicatedTensor, *, count: int) -> ReplicatedTensor: - if input.shard_count != count: +def replicate_replicated( + input: ReplicatedTensor, *, devices: list, count: int +) -> ReplicatedTensor: + if devices is not None and input.shard_count != len(devices): + raise ValueError( + f"Number of shards not equal ({input.shard_count} != {len(devices)})" + ) + if count is not None and input.shard_count != count: raise ValueError(f"Number of shards not equal ({input.shard_count} != {count})") return input @replicate.override(SplitPrimitiveTensor) -def replicate_split(input: SplitPrimitiveTensor, *, count: int) -> ReplicatedTensor: - if input.shard_count != count: +def replicate_split( + input: SplitPrimitiveTensor, *, devices: list, count: int +) -> ReplicatedTensor: + if devices is not None and input.shard_count != len(devices): + raise ValueError( + f"Number of shards not equal ({input.shard_count} != {len(devices)})" + ) + if count is not None and input.shard_count != count: raise ValueError(f"Number of shards not equal ({input.shard_count} != {count})") return all_gather(input) @replicate.override(UnreducedTensor) -def replicate_unreduced(input: UnreducedTensor, *, count: int) -> ReplicatedTensor: - if input.shard_count != count: +def replicate_unreduced( + input: UnreducedTensor, *, devices: list, count: int +) -> ReplicatedTensor: + if devices is not None and input.shard_count != len(devices): + raise ValueError( + f"Number of shards not equal ({input.shard_count} != {len(devices)})" + ) + if count is not None and input.shard_count != count: raise ValueError(f"Number of shards not equal ({input.shard_count} != {count})") return all_reduce(input) @replicate.override(Tensor) -def replicate_unsharded(input, *, count: int) -> ReplicatedTensor: +def replicate_unsharded(input, *, devices: list, count: int) -> ReplicatedTensor: torch_input = unbox_tensor(input) # If we have a torch input replicating we can assume we need to transfer: - torch_inputs = [transfer_to_logical_device(torch_input, i) for i in range(count)] - return ReplicatedTensor(ts=torch_inputs) + if devices is not None: + torch_inputs = [ + transfer_to_logical_device(torch_input, d.ordinal) for d in devices + ] + return ReplicatedTensor(ts=torch_inputs, devices=devices) + + if count is not None: + devices = [DeviceAffinity(i) for i in range(count)] + torch_inputs = [ + transfer_to_logical_device(torch_input, i) for i in range(count) + ] + return ReplicatedTensor(ts=torch_inputs, devices=devices) + + raise ValueError(f"Devices or count is required") @reshape.override(SplitPrimitiveTensor) @@ -975,7 +1047,8 @@ def reshard_all_to_unsharded(input: AnyTensor, spec: sharding.Unsharded) -> Tens def reshard_all_to_replicated( input: AnyTensor, spec: sharding.Replicated ) -> ReplicatedTensor: - return replicate(input, spec.shard_count) + device = [DeviceAffinity(i) for i in range(spec.shard_count)] + return replicate(input, device) @reshard_split.override(Tensor) @@ -1018,7 +1091,7 @@ def slice_range_along_dim(dim: int, start: int, end: int): ] for shard_idx, shard in enumerate(input.shards) ] - return SplitPrimitiveTensor(ts=shards, shard_dim=dim) + return SplitPrimitiveTensor(ts=shards, shard_dim=dim, devices=input.devices) @reshard_like.override(Tensor, SplitPrimitiveTensor) @@ -1044,7 +1117,7 @@ def reshard_like_unsharded_to_replicated( tensor, like: ReplicatedTensor ) -> ReplicatedTensor: torch_tensor = unbox_tensor(tensor) - return replicate(torch_tensor, count=like.shard_count) + return 
replicate(torch_tensor, devices=like.devices) @reshard_like.override(ReplicatedTensor, ReplicatedTensor) @@ -1079,14 +1152,14 @@ def reshard_like_split_to_split( def reshard_like_unreduced_to_replicated( tensor: UnreducedTensor, like: ReplicatedTensor ) -> ReplicatedTensor: - return replicate(tensor, count=like.shard_count) + return replicate(tensor, devices=like.devices) @sharded_cat.override(SplitPrimitiveTensor) def sharded_cat_unsharded(tensor: SplitPrimitiveTensor): shard_ts = [ - transfer_to_logical_device(shard.as_torch(), 0) if i != 0 else shard.as_torch() - for i, shard in enumerate(tensor.shards) + transfer_to_logical_device(shard.as_torch(), tensor.devices[0].ordinal) + for shard in tensor.shards ] return torch.cat(shard_ts, dim=tensor.shard_dim) @@ -1122,20 +1195,25 @@ def softmax_split( ), "Softmax along split dimension is not supported." shards = [softmax(shard, dim=dim, dtype=dtype) for shard in tensor.shards] return SplitPrimitiveTensor( - ts=shards, shard_dim=tensor.shard_dim, shape=tensor.shape + ts=shards, + shard_dim=tensor.shard_dim, + shape=tensor.shape, + devices=tensor.devices, ) @to.override(ReplicatedTensor) def to_replicated(tensor: ReplicatedTensor, *args, **kwargs): shards = [to(shard, *args, **kwargs) for shard in tensor.shards] - return ReplicatedTensor(ts=shards) + return ReplicatedTensor(ts=shards, devices=tensor.devices) @to.override(SplitPrimitiveTensor) def to_split(tensor: SplitPrimitiveTensor, *args, **kwargs): shards = [to(shard, *args, **kwargs) for shard in tensor.shards] - return SplitPrimitiveTensor(ts=shards, shard_dim=tensor.shard_dim) + return SplitPrimitiveTensor( + ts=shards, shard_dim=tensor.shard_dim, devices=tensor.devices + ) @transpose.override(SplitPrimitiveTensor) @@ -1148,7 +1226,7 @@ def transpose_split( shard_dim = dim1 elif shard_dim == dim1: shard_dim = dim0 - return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim) + return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim, devices=tensor.devices) @unflatten.override(SplitPrimitiveTensor) @@ -1160,7 +1238,7 @@ def unflatten_split( shard_dim = input.shard_dim if dim < shard_dim: shard_dim += len(sizes) - 1 - return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim) + return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim, devices=input.devices) @unshard.override(ReplicatedTensor) @@ -1177,8 +1255,7 @@ def unshard_split(input: SplitPrimitiveTensor) -> Tensor: def unshard_unreduced(input: UnreducedTensor) -> Tensor: shards = input.shards shards = [ - shard if i == 0 else transfer_to_logical_device(shard, 0) - for i, shard in enumerate(shards) + transfer_to_logical_device(shard, input.devices[0].ordinal) for shard in shards ] return functools.reduce(lambda x, y: elementwise(torch.add, x, y), shards) @@ -1268,13 +1345,13 @@ def unsqueeze_split(tensor: SplitPrimitiveTensor, dim: int) -> SplitPrimitiveTen dim_resolved = dim if dim >= 0 else dim + len(tensor.shape) + 1 if shard_dim >= dim_resolved: shard_dim += 1 - return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim) + return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim, devices=tensor.devices) @unsqueeze.override(ReplicatedTensor) def unsqueeze_replicated(tensor: ReplicatedTensor, dim: int) -> SplitPrimitiveTensor: shards = [torch.unsqueeze(unbox_tensor(shard), dim) for shard in tensor.shards] - return ReplicatedTensor(ts=shards) + return ReplicatedTensor(ts=shards, devices=tensor.devices) @view.override(SplitPrimitiveTensor) @@ -1303,7 +1380,7 @@ def view_split(tensor: SplitPrimitiveTensor, shape: List[int]) -> 
SplitPrimitive new_shard_shape = list(shape) new_shard_shape[shard_dim] //= tensor.shard_count shards = [view(shard, new_shard_shape) for shard in tensor.shards] - res = SplitPrimitiveTensor(shard_dim=shard_dim, ts=shards) + res = SplitPrimitiveTensor(shard_dim=shard_dim, ts=shards, devices=tensor.devices) assert math.prod(res.shape) == math.prod(tensor.shape) return res @@ -1311,22 +1388,26 @@ def view_split(tensor: SplitPrimitiveTensor, shape: List[int]) -> SplitPrimitive @view_as_complex.override(SplitPrimitiveTensor) def view_as_complex_split(tensor: SplitPrimitiveTensor) -> SplitPrimitiveTensor: shards = [view_as_complex(shard) for shard in tensor.shards] - return SplitPrimitiveTensor(ts=shards, shard_dim=tensor.shard_dim) + return SplitPrimitiveTensor( + ts=shards, shard_dim=tensor.shard_dim, devices=tensor.devices + ) @view_as_complex.override(ReplicatedTensor) def view_as_complex_rep(tensor: ReplicatedTensor) -> ReplicatedTensor: shards = [view_as_complex(shard) for shard in tensor.shards] - return ReplicatedTensor(ts=shards) + return ReplicatedTensor(ts=shards, devices=tensor.devices) @view_as_real.override(SplitPrimitiveTensor) def view_as_real_split(tensor: SplitPrimitiveTensor) -> SplitPrimitiveTensor: shards = [view_as_real(shard) for shard in tensor.shards] - return SplitPrimitiveTensor(ts=shards, shard_dim=tensor.shard_dim) + return SplitPrimitiveTensor( + ts=shards, shard_dim=tensor.shard_dim, devices=tensor.devices + ) @view_as_real.override(ReplicatedTensor) def view_as_real_rep(tensor: ReplicatedTensor) -> ReplicatedTensor: shards = [view_as_real(shard) for shard in tensor.shards] - return ReplicatedTensor(ts=shards) + return ReplicatedTensor(ts=shards, devices=tensor.devices) diff --git a/sharktank/sharktank/ops/signatures.py b/sharktank/sharktank/ops/signatures.py index cbe959d28..c5dadb698 100644 --- a/sharktank/sharktank/ops/signatures.py +++ b/sharktank/sharktank/ops/signatures.py @@ -796,7 +796,9 @@ def _repeat_trampoline( @overridable -def replicate(input: AnyTensor, count: int) -> ShardedTensor: +def replicate( + input: AnyTensor, devices: list = None, count: int = None +) -> ShardedTensor: """Replicate across devices. 
Possibly reshards if required.""" @@ -805,11 +807,11 @@ def replicate(input: AnyTensor, count: int) -> ShardedTensor: @replicate.trampoline def _replicate_trampoline( - d: SignatureDispatcher, input: AnyTensor, count: int + d: SignatureDispatcher, input: AnyTensor, devices: list = None, count: int = None ) -> ShardedTensor: tensors = (input,) for override in d.find_overrides(tensors): - result = override(input, count=count) + result = override(input, devices=devices, count=count) if result is not NotImplemented: return override, result else: diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py index 153a5d753..c4fad275f 100644 --- a/sharktank/sharktank/types/tensors.py +++ b/sharktank/sharktank/types/tensors.py @@ -28,8 +28,10 @@ from torch import Tensor from torch.utils._pytree import register_pytree_node, SequenceKey import torch.utils._pytree + from ..utils.math import ceildiv from iree.turbine.aot import ( + DeviceAffinity, DeviceTensorTrait, ExternalTensorTrait, ) @@ -791,6 +793,7 @@ def __init__( ts: list[torch.Tensor], name: str = UnnamedTensorName, shape: Optional[list[int]], + devices: Optional[list], ): assert len(ts) > 0 assert shard_dim is None or (shard_dim >= 0 and len(ts[0].shape) > shard_dim) @@ -805,6 +808,19 @@ def __init__( for i, t in enumerate(ts) ) + if devices is None: + devices = [DeviceAffinity(i) for i in range(len(self._shards))] + + self._devices: tuple[DeviceAffinity] = tuple(devices) + + def assign_affinities(self, affinities): + assert len(affinities) == len(self._devices) + self._devices = tuple(affinities) + for s, d in zip(self._shards, self._devices): + if isinstance(s, DefaultPrimitiveTensor): + s = s.as_torch() + DeviceTensorTrait(d.ordinal, d.queues).set(s) + @property def shard_count(self) -> int: return len(self._shards) @@ -813,6 +829,10 @@ def shard_count(self) -> int: def shards(self) -> tuple[InferenceTensor]: return self._shards + @property + def devices(self) -> tuple[DeviceAffinity]: + return self._devices + @property def is_replicated(self) -> bool: return False @@ -870,9 +890,8 @@ def create( t_name = str(i) try: t = raw_tensors[t_name] - ts.append(t) - # TODO: this should be changed to tracked device affinity DeviceTensorTrait(i).set(t) + ts.append(t) except KeyError as e: raise IOError( f"Missing component tensor '{t_name}' in {raw_tensors.keys()}" @@ -965,6 +984,7 @@ def __init__( shard_count: None | int = None, name: str = UnnamedTensorName, shape: Optional[list[int]] = None, + devices: list | None = None, ): """ If `ts` is a list of tensors, it is interpreted as the shards. @@ -973,16 +993,25 @@ def __init__( will be split along dimension `shard_dim` into `shard_count` number of pieces. 
""" - if isinstance(ts, torch.Tensor): - from ..ops import transfer_to_logical_device + assert shard_count is None or isinstance(ts, torch.Tensor) + shard_count = shard_count if shard_count is not None else len(ts) + + if devices is None: + devices = [DeviceAffinity(i) for i in range(shard_count)] + + assert len(devices) == shard_count + + if isinstance(ts, torch.Tensor): assert shard_count is not None ts = ts.split(ceildiv(ts.shape[shard_dim], shard_count), dim=shard_dim) - ts = [transfer_to_logical_device(t, i) for i, t in enumerate(ts)] + + from ..ops import transfer_to_logical_device + + ts = [transfer_to_logical_device(t, d.ordinal) for t, d in zip(ts, devices)] assert len(ts) == shard_count shard_count = None - assert shard_count is None assert len(ts) > 0 first_shape = ts[0].shape assert len(first_shape) > shard_dim @@ -1004,7 +1033,9 @@ def __init__( s == t for i, (s, t) in enumerate(zip(shape, t_shape)) if i != shard_dim ), f"Shape mismatch for non-split dimension for tensor shard {i} with shape {t.shape}" - super().__init__(name=name, ts=ts, shape=shape, shard_dim=shard_dim) + super().__init__( + name=name, ts=ts, shape=shape, shard_dim=shard_dim, devices=devices + ) def _is_slicing_split_dim(self, key): if isinstance( @@ -1072,7 +1103,9 @@ def __getitem__(self, key): # Rank reduction dimension before the split dim. shard_dim -= 1 - return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim) + return SplitPrimitiveTensor( + ts=shards, shard_dim=shard_dim, devices=self.devices + ) def __setitem__(self, key, value): assert isinstance(value, SplitPrimitiveTensor) @@ -1098,7 +1131,8 @@ class ReplicatedTensor(ShardedTensor): def __init__( self, *, - ts: list[torch.Tensor] | torch.Tensor, + ts: list[torch.Tensor], + devices: None | list[DeviceAffinity] = None, shard_count: None | int = None, name: str = UnnamedTensorName, ): @@ -1108,13 +1142,7 @@ def __init__( If `ts` is a tensor then `shard_count` must be provided and it, will be replicated that many times. 
""" - if isinstance(ts, torch.Tensor): - assert shard_count is not None - from ..ops import transfer_to_logical_device - - ts = [transfer_to_logical_device(ts, i) for i in range(shard_count)] - shard_count = None - + assert not isinstance(ts, torch.Tensor) assert shard_count is None assert len(ts) > 0 first_shape = ts[0].shape @@ -1134,6 +1162,19 @@ def __init__( for i, t in enumerate(ts) ) + if devices is None: + devices = tuple([DeviceAffinity(i) for i in range(len(ts))]) + + self._devices: tuple[DeviceAffinity] = tuple(devices) + + def assign_affinities(self, affinities): + assert len(affinities) == len(self._devices) + self._devices = tuple(affinities) + for s, d in zip(self._shards, self._devices): + if isinstance(s, DefaultPrimitiveTensor): + s = s.as_torch() + DeviceTensorTrait(d.ordinal, d.queues).set(s) + @property def shard_count(self) -> int: return len(self._shards) @@ -1142,6 +1183,10 @@ def shard_count(self) -> int: def shards(self) -> tuple[InferenceTensor]: return self._shards + @property + def devices(self) -> tuple[DeviceAffinity]: + return self._devices + @property def is_replicated(self) -> bool: return True @@ -1188,9 +1233,8 @@ def create( nt = deepcopy(t) ts.append(nt) - # TODO This should be changed to assigned affinities - for i in range(shard_count): - DeviceTensorTrait(i).set(ts[i]) + for i, t in enumerate(ts): + DeviceTensorTrait(i).set(t) except KeyError as e: raise IOError(f"Missing component tensor '' in {raw_tensors.keys()}") from e @@ -1210,12 +1254,13 @@ def __getitem__(self, key): else: shard_keys.append(k) shards.append(shard[*shard_keys]) - return ReplicatedTensor(ts=shards) + return ReplicatedTensor(ts=shards, devices=self.devices) def __repr__(self): return ( f"ReplicatedTensor({self.name}, {self.shape}, " f"shard_count={len(self._shards)} " + f"devices={self.devices} " f"of {self.shards[0].shape})" ) @@ -1243,13 +1288,14 @@ def __init__( self, *, ts: list[torch.Tensor], + devices: Optional[list] = None, name: str = UnnamedTensorName, shape: Optional[list[int]] = None, ): assert len(ts) > 0 shape = list(ts[0].shape if shape is None else shape) assert all(shape == list(t.shape) for t in ts) - super().__init__(name=name, ts=ts, shape=shape, shard_dim=None) + super().__init__(name=name, ts=ts, shape=shape, shard_dim=None, devices=devices) def flatten_tensor_tree( diff --git a/sharktank/sharktank/types/theta.py b/sharktank/sharktank/types/theta.py index 021925169..8c6d235dd 100644 --- a/sharktank/sharktank/types/theta.py +++ b/sharktank/sharktank/types/theta.py @@ -25,8 +25,6 @@ from .tensors import ( InferenceTensor, - PrimitiveTensor, - QuantizedTensor, InferenceTensorMetadata, DefaultPrimitiveTensor, REGISTERED_INFERENCE_TENSOR_CLASSES, diff --git a/sharktank/sharktank/utils/cli.py b/sharktank/sharktank/utils/cli.py index 99917c2d3..c90e03cea 100644 --- a/sharktank/sharktank/utils/cli.py +++ b/sharktank/sharktank/utils/cli.py @@ -69,9 +69,14 @@ def add_model_options(parser: argparse.ArgumentParser): default="decomposed", choices=["decomposed", "torch"], ) + parser.add_argument( + "--skip-prefill", + help="Skips exporting prefill", + action="store_true", + ) parser.add_argument( "--skip-decode", - help="Enables prefill only, skips decode", + help="Skips exporting decode", action="store_true", ) diff --git a/sharktank/tests/layers/kv_cache_test.py b/sharktank/tests/layers/kv_cache_test.py index 65b42c986..228535ee8 100644 --- a/sharktank/tests/layers/kv_cache_test.py +++ b/sharktank/tests/layers/kv_cache_test.py @@ -8,6 +8,7 @@ import torch +from 
iree.turbine.aot import DeviceAffinity from sharktank.ops import replicate, reshard_split, unshard from sharktank.layers import * from sharktank.types import * @@ -148,6 +149,8 @@ def test_sharded_direct(): write_seq_length = seq_length - 5 + devices = [DeviceAffinity(i) for i in range(shard_count)] + # Write a prefill in: write_ones = reshard_split( torch.full( @@ -204,7 +207,7 @@ def test_sharded_direct(): ) write_pos = replicate( - torch.full((bs,), write_seq_length, dtype=torch.int64), shard_count + torch.full((bs,), write_seq_length, dtype=torch.int64), devices ) cache.write_timestep( allocation, @@ -379,11 +382,13 @@ def test_sharded_paged(): device=None, ) + devices = [DeviceAffinity(i) for i in range(shard_count)] + write_seq_length = seq_length - 4 page_count = bs * seq_length // block_seq_stride page_ids = torch.arange(page_count, dtype=torch.int64) page_ids = page_ids.view(bs, seq_length // block_seq_stride) - page_ids = replicate(page_ids, shard_count) + page_ids = replicate(page_ids, devices=devices) write_page_ids = page_ids[:, : write_seq_length // block_seq_stride] allocation = cache.allocate(page_count=page_count) @@ -458,7 +463,7 @@ def test_sharded_paged(): ) write_pos = replicate( - torch.full((bs,), write_seq_length, dtype=torch.int64), shard_count + torch.full((bs,), write_seq_length, dtype=torch.int64), devices ) cache.write_timestep( diff --git a/sharktank/tests/layers/sharded_paged_kv_cache_test.py b/sharktank/tests/layers/sharded_paged_kv_cache_test.py index d7b6a0b33..c54b20459 100644 --- a/sharktank/tests/layers/sharded_paged_kv_cache_test.py +++ b/sharktank/tests/layers/sharded_paged_kv_cache_test.py @@ -5,13 +5,16 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import unittest -from sharktank.layers import PagedKVCache import torch -from sharktank.utils import iterables_equal from copy import deepcopy from typing import List, Tuple + +from iree.turbine.aot import DeviceAffinity + from sharktank import ops +from sharktank.layers import PagedKVCache from sharktank.types import SplitPrimitiveTensor +from sharktank.utils import iterables_equal class ShardedPagedKVCacheTest(unittest.TestCase): @@ -31,6 +34,7 @@ def setUp(self): self.batch_size = 11 self.block_seq_len = 2 self.max_seq_len = self.block_seq_len * self.block_seq_stride + self.devices = [DeviceAffinity(i) for i in range(self.shard_count)] self.cache = PagedKVCache( transformer_block_count=self.transformer_block_count, @@ -131,7 +135,7 @@ def testRead(self): for t in read_into_partitions_snapshot ] ) - sharded_page_ids = ops.replicate(page_ids, count=self.shard_count) + sharded_page_ids = ops.replicate(page_ids, devices=self.devices) self.sharded_cache.read( state=sharded_cache_state, read_into_partitions=sharded_read_into_partitions, @@ -179,8 +183,8 @@ def testWriteTimestep(self): for t in cache_partitions ] ) - sharded_seq_positions = ops.replicate(seq_positions, count=self.shard_count) - sharded_page_ids = ops.replicate(page_ids, count=self.shard_count) + sharded_seq_positions = ops.replicate(seq_positions, devices=self.devices) + sharded_page_ids = ops.replicate(page_ids, devices=self.devices) self.sharded_cache.write_timestep( state=sharded_cache_state, cache_partitions=sharded_cache_partitions, @@ -224,7 +228,7 @@ def testWrite(self): for t in cache_partitions ] ) - sharded_page_ids = ops.replicate(page_ids, count=self.shard_count) + sharded_page_ids = ops.replicate(page_ids, devices=self.devices) self.sharded_cache.write( state=sharded_cache_state, 
cache_partitions=sharded_cache_partitions, diff --git a/sharktank/tests/layers/sharded_rotary_embedding_test.py b/sharktank/tests/layers/sharded_rotary_embedding_test.py index f24b8313a..2d2d95691 100644 --- a/sharktank/tests/layers/sharded_rotary_embedding_test.py +++ b/sharktank/tests/layers/sharded_rotary_embedding_test.py @@ -7,6 +7,8 @@ import torch +from iree.turbine.aot import DeviceAffinity + from sharktank.layers import RotaryEmbeddingLayer from sharktank import ops from sharktank.types import ( @@ -27,6 +29,8 @@ def test_sharded_rotary_table(): max_seqlen = 128 rope_freq_base = None + devices = [DeviceAffinity(i) for i in range(4)] + # First we setup and get the default rotary embedding layer xq = torch.rand((bs, max_seqlen, heads, rope_dims), dtype=torch.float) xk = torch.rand((bs, max_seqlen, heads, rope_dims), dtype=torch.float) @@ -45,7 +49,7 @@ def test_sharded_rotary_table(): rope_dimension_count=rope_dims, max_seqlen=max_seqlen, rope_freq_base=rope_freq_base, - tensor_parallelism_size=4, + devices=devices, ) sq = shard_layer(xt=xq, start_index=0) sk = shard_layer(xt=xk, start_index=0) diff --git a/sharktank/tests/models/llama/sharded_llama_test.py b/sharktank/tests/models/llama/sharded_llama_test.py index 386061731..b1d0aa745 100644 --- a/sharktank/tests/models/llama/sharded_llama_test.py +++ b/sharktank/tests/models/llama/sharded_llama_test.py @@ -6,7 +6,7 @@ import unittest import pytest -from typing import Any, List, Tuple, OrderedDict +from typing import Any, Tuple, OrderedDict from sharktank.models.llama.llama import LlamaModelConfig, PagedLlamaModelV1 import sharktank.ops as ops from sharktank.types import unbox_tensor, Dataset, UnreducedTensor, SplitPrimitiveTensor diff --git a/sharktank/tests/ops/sharded_test.py b/sharktank/tests/ops/sharded_test.py index e5efaa948..adfb38359 100644 --- a/sharktank/tests/ops/sharded_test.py +++ b/sharktank/tests/ops/sharded_test.py @@ -630,7 +630,7 @@ def testAttentionShardedBatchMask(self): q_s = SplitPrimitiveTensor(shard_dim=0, ts=q.split(1, dim=0)) k_s = SplitPrimitiveTensor(shard_dim=0, ts=k.split(1, dim=0)) v_s = SplitPrimitiveTensor(shard_dim=0, ts=v.split(1, dim=0)) - a_s = ReplicatedTensor(ts=a, shard_count=4) + a_s = ReplicatedTensor(ts=[a] * 4) expected_result = ops.scaled_dot_product_attention( q, k, v, a=a, is_causal=False @@ -754,7 +754,7 @@ def testShardedLhsReplcatedRhs(self): expected_result = torch.matmul(a, b) shard_count = 3 a_sharded = SplitPrimitiveTensor(ts=a, shard_dim=1, shard_count=shard_count) - b_sharded = ReplicatedTensor(ts=b, shard_count=shard_count) + b_sharded = ReplicatedTensor(ts=[b] * shard_count) res_sharded = ops.matmul(a_sharded, b_sharded) assert isinstance(res_sharded, SplitPrimitiveTensor) assert res_sharded.shard_dim == 1 @@ -837,7 +837,7 @@ def testReplicateUnsharded(self): tensor = torch.rand(4, 5, dtype=torch.float32) shard_count = 3 actual_result = ops.replicate(tensor, count=shard_count) - expected_result = ReplicatedTensor(ts=tensor, shard_count=shard_count) + expected_result = ReplicatedTensor(ts=[tensor] * shard_count) assert expected_result.is_deep_equal(actual_result) # Test that is a copy. 
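
Reviewer note (annotation, not part of the patch): the hunks above replace the old `count=` / `shard_count=` plumbing with explicit `DeviceAffinity` lists. Below is a minimal usage sketch of that API, assuming the signatures introduced above in sharktank/sharktank/ops/signatures.py and sharktank/sharktank/types/tensors.py; the tensor shapes and values are illustrative only.

import torch
from iree.turbine.aot import DeviceAffinity

from sharktank import ops
from sharktank.types import ReplicatedTensor, SplitPrimitiveTensor

# Explicit device affinities replace the bare shard counts used before this patch.
devices = [DeviceAffinity(i) for i in range(2)]

# Replicate an unsharded tensor onto the listed devices
# (previously: ops.replicate(x, count=2)).
x = torch.rand(4, 5)
x_rep = ops.replicate(x, devices=devices)

# Split a weight across the same devices; when `devices` is omitted the
# constructor defaults to DeviceAffinity(0..shard_count-1).
w = torch.rand(6, 5)
w_split = SplitPrimitiveTensor(ts=w, shard_dim=0, shard_count=2, devices=devices)

# ReplicatedTensor now takes a list of shards plus optional affinities instead of
# a single tensor with shard_count (see the test updates above).
x_rep2 = ReplicatedTensor(ts=[x] * 2, devices=devices)

# Queue affinities (the --use-queue-affinities path) keep every shard on device 0
# but on distinct queues, mirroring setup_queue_affinities() in export_paged_llm_v1.py.
queue_affinities = [DeviceAffinity(0, [str(i)]) for i in range(2)]
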
diff --git a/sharktank/tests/types/tensors_test.py b/sharktank/tests/types/tensors_test.py index 4af4a513f..37898a7c5 100644 --- a/sharktank/tests/types/tensors_test.py +++ b/sharktank/tests/types/tensors_test.py @@ -106,7 +106,7 @@ def testUnreducedTensorSaveLoad(self): def testReplicatedTensorExtractSlice(self): tensor = torch.rand([2, 3, 4], dtype=torch.float32) - replicated_tensor = ReplicatedTensor(ts=tensor, shard_count=3) + replicated_tensor = ReplicatedTensor(ts=[tensor] * 3) s = [slice(1, 2), slice(0, 3, 2), None] expected_result = tensor[s] replicated_sliced_tensor = replicated_tensor[s] @@ -116,7 +116,7 @@ def testReplicatedTensorExtractSlice(self): def testReplicatedTensorExtractElement(self): tensor = torch.rand([2, 3, 4], dtype=torch.float32) - replicated_tensor = ReplicatedTensor(ts=tensor, shard_count=3) + replicated_tensor = ReplicatedTensor(ts=[tensor] * 3) idx = ( 1, 2,