Expanded sharding support for alternative sharding mechanisms
Single-logical-multi-physical sharding allows tensor access between
different devices and tighter synchronization on execution. This means
that sharding needs to support not only differing device ordinals but
also configuring multiple queues for the same device. Sharded tensor
types are reworked to track both the device each shard is placed on AND
the queue it is enqueued on.

To support this, each sharded tensor now tracks the DeviceAffinity it
is associated with and allows its affinities to be reassigned after
construction. This lets pre-sharded models have their affinities
updated to use an alternative transfer mechanism.
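
As a rough illustration of the reassignment path, the sketch below mirrors
the assign_affinities helper added to export_paged_llm_v1.py in this commit;
it assumes ShardedTensor is importable from sharktank.types and DeviceAffinity
from iree.turbine.aot, and the helper name retarget_shards is illustrative:

from iree.turbine.aot import DeviceAffinity
from sharktank.types import ShardedTensor

def retarget_shards(
    tensor: ShardedTensor, affinities: list[DeviceAffinity]
) -> ShardedTensor:
    # One affinity per shard. The shard data and layout are unchanged; only
    # the device/queue each shard is associated with is updated in place.
    tensor.assign_affinities(affinities)
    return tensor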

If device affinities are not specified, the default arrangement assumes
a separate device ordinal for each shard.
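
For reference, a hedged sketch of the two arrangements (the DeviceAffinity
constructor calls are copied from the export script changes below; the shard
count of 4 is arbitrary):

from iree.turbine.aot import DeviceAffinity

tensor_parallelism_size = 4

# Default arrangement: shard i is placed on device ordinal i.
ordinal_affinities = [
    DeviceAffinity(i) for i in range(tensor_parallelism_size)
]

# Queue arrangement: every shard shares device ordinal 0, each pinned to its
# own queue, matching --use-queue-affinities in export_paged_llm_v1.py.
queue_affinities = [
    DeviceAffinity(0, [str(i)]) for i in range(tensor_parallelism_size)
]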
rsuderman committed Dec 12, 2024
1 parent 4c015d4 commit afbb8e6
Showing 10 changed files with 299 additions and 149 deletions.
82 changes: 55 additions & 27 deletions sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -17,11 +17,12 @@

# TODO: Should be using a base class with the protocol supported.
from ..models.llama.llama import LlamaModelConfig, PagedLlamaModelV1
from ..models.llama.sharding import shard_theta
from ..models.mixtral.mixtral import *
from ..models.grok.grok import *
from .. import ops

from typing import Union, Sequence


def main():
from ..utils import cli
@@ -54,6 +55,11 @@ def main():
help="Enables strictness during export",
action="store_true",
)
parser.add_argument(
"--use-queue-affinities",
help="Enables queue affinities for multidevice",
action="store_true",
)

cli.add_quantization_options(parser)
cli.add_model_options(parser)
@@ -73,18 +79,44 @@ def main():
tensor_parallelism_size=tensor_parallelism_size,
use_hf=False,
static_tables=False, # Rely on the compiler for hoisting tables.
kv_cache_type="direct" if args.bs == [1] else "paged",
kv_cache_type="paged" if args.bs == [1] else "paged",
attention_kernel=args.attention_kernel,
)
llama_config.fake_quant = args.fake_quant

def setup_queue_affinities(sharding):
return [DeviceAffinity(0, [str(i)]) for i in range(sharding)]

def assign_affinities(theta, affinities):
def transform(tensors: Union[None, InferenceTensor, Sequence[InferenceTensor]]):
if tensors is None:
return tensors
if isinstance(tensors, Sequence):
return [transform(t) for t in tensors]
if isinstance(tensors, ShardedTensor):
tensors.assign_affinities(affinities)
return tensors
if isinstance(tensors, InferenceTensor):
return tensors

raise ValueError("Unknown device for reassigned affinities")

return theta.transform(transform)

affinities = [
DeviceAffinity(i) for i in range(llama_config.tensor_parallelism_size)
]
if args.use_queue_affinities:
affinities = setup_queue_affinities(llama_config.tensor_parallelism_size)
dataset.root_theta = assign_affinities(dataset.root_theta, affinities)

if llama_config.hp.expert_count:
if llama_config.hp.model_arch == "grok":
model = PagedGrokModelV1(dataset.root_theta, llama_config)
else:
model = PagedMixtralModelV1(dataset.root_theta, llama_config)
else:
model = PagedLlamaModelV1(dataset.root_theta, llama_config)
model = PagedLlamaModelV1(dataset.root_theta, llama_config, affinities)

def generate_params_json(hp, prefill_bs: list[int], decode_bs: list[int]):
return {
@@ -108,7 +140,7 @@ def generate_params_json(hp, prefill_bs: list[int], decode_bs: list[int]):

fxb = FxProgramsBuilder(model)

def setup_cache(model, shard_count):
def setup_cache(model, affinities):
if model.config.kv_cache_type == "paged":
cache_state = model.cache.allocate(
page_count=hp.context_length // llama_config.block_seq_stride
@@ -130,24 +162,21 @@ def setup_cache(model, shard_count):
]

for i in range(llama_config.tensor_parallelism_size):
arg_affinities[i] = DeviceAffinity(str(i))
arg_affinities[i] = affinities[i]

return unpacked, shard_dim, dynamic_shapes, arg_affinities

elif model.config.kv_cache_type == "direct":
cache_state = model.cache.allocate(bs=1)
# Direct cache dimensions:
# 2 * transformer_block_count of...
# [bs, seq_length, attn_head_count, attn_head_dim]
dynamic_shapes = [None]
arg_affinities = {}
shard_dim = None
return torch.stack(cache_state), shard_dim, dynamic_shapes, arg_affinities
raise NotImplementedError(f"Direct cache is not currently functional")

else:
raise NotImplementedError(f"Unsupported KV cache type: {type(model.cache)}")

def repack_cache(cache, shard_dim):
return [SplitPrimitiveTensor(ts=c, shard_dim=shard_dim) for c in cache]
return [
SplitPrimitiveTensor(ts=c, shard_dim=shard_dim, devices=affinities)
for c in cache
]

def generate_batch_prefill(bs: int):
# torch.export.Dim would make min at least 2
@@ -164,7 +193,7 @@ def generate_batch_prefill(bs: int):
seq_lens = torch.empty(bs, dtype=torch.int64)

cache, cache_shard_dim, cache_dynamic_shapes, arg_affinities = setup_cache(
model, llama_config.tensor_parallelism_size
model, affinities
)

if llama_config.tensor_parallelism_size > 1:
@@ -206,9 +235,9 @@ def _(model, tokens, seq_lens, seq_block_ids, cs):
if llama_config.tensor_parallelism_size != 1:
shard_count = llama_config.tensor_parallelism_size

tokens = ops.replicate(tokens, count=shard_count)
attention_mask = ops.replicate(attention_mask, count=shard_count)
seq_block_ids = ops.replicate(seq_block_ids, count=shard_count)
tokens = ops.replicate(tokens, devices=affinities)
attention_mask = ops.replicate(attention_mask, devices=affinities)
seq_block_ids = ops.replicate(seq_block_ids, devices=affinities)

cache_tensors = repack_cache(cs, cache_shard_dim)

@@ -243,15 +272,15 @@ def generate_batch_decode(bs: int):
cache_shard_dim,
cache_dynamic_shapes,
arg_affinities,
) = setup_cache(model, llama_config.tensor_parallelism_size)
) = setup_cache(model, affinities)

if llama_config.tensor_parallelism_size > 1:
# We need to offset the indices for the cache
arg_affinities = {key + 4: arg_affinities[key] for key in arg_affinities}

# Inputs have default affinity 0
for i in range(4):
arg_affinities[i] = DeviceAffinity("0")
arg_affinities[i] = affinities[0]

dynamic_shapes = {
"tokens": {},
Expand Down Expand Up @@ -290,12 +319,10 @@ def _(
attention_mask = model.decode_attention_mask(input_mask)

if llama_config.tensor_parallelism_size != 1:
shard_count = llama_config.tensor_parallelism_size

tokens = ops.replicate(tokens, count=shard_count)
attention_mask = ops.replicate(attention_mask, count=shard_count)
start_positions = ops.replicate(start_positions, count=shard_count)
seq_block_ids = ops.replicate(seq_block_ids, count=shard_count)
tokens = ops.replicate(tokens, devices=affinities)
attention_mask = ops.replicate(attention_mask, devices=affinities)
start_positions = ops.replicate(start_positions, devices=affinities)
seq_block_ids = ops.replicate(seq_block_ids, devices=affinities)

cache_state = repack_cache(cache_state, cache_shard_dim)

@@ -314,7 +341,8 @@

bsizes = []
for bs in args.bs:
generate_batch_prefill(bs)
if not args.skip_prefill:
generate_batch_prefill(bs)
if not args.skip_decode:
generate_batch_decode(bs)
bsizes.append(bs)
23 changes: 13 additions & 10 deletions sharktank/sharktank/examples/paged_llm_v1.py
@@ -15,6 +15,8 @@

import torch

from iree.turbine.aot import DeviceAffinity

from ..layers import *
from ..types import *

@@ -156,9 +158,10 @@ def prefill(self):
token_ids = self.token_ids
if model.config.tensor_parallelism_size != 1:
tp = model.config.tensor_parallelism_size
token_ids = replicate(token_ids, tp)
attention_mask = replicate(attention_mask, tp)
seq_block_ids_tensor = replicate(seq_block_ids_tensor, tp)
devices = [DeviceAffinity(i) for i in range(tp)]
token_ids = replicate(token_ids, devices)
attention_mask = replicate(attention_mask, devices)
seq_block_ids_tensor = replicate(seq_block_ids_tensor, devices)

logits = model.prefill(
token_ids,
@@ -199,10 +202,11 @@ def decode(self):

if model.config.tensor_parallelism_size != 1:
tp = model.config.tensor_parallelism_size
self.next_tokens = replicate(self.next_tokens, tp)
start_positions = replicate(start_positions, tp)
seq_block_ids_tensor = replicate(seq_block_ids_tensor, tp)
decode_attention_mask = replicate(decode_attention_mask, tp)
devices = [DeviceAffinity(i) for i in range(tp)]
self.next_tokens = replicate(self.next_tokens, devices)
start_positions = replicate(start_positions, devices)
seq_block_ids_tensor = replicate(seq_block_ids_tensor, devices)
decode_attention_mask = replicate(decode_attention_mask, devices)

logits = model.decode(
self.next_tokens,
@@ -279,16 +283,15 @@ def main():
tensor_parallelism_size=args.tensor_parallelism_size,
fake_quant=args.fake_quant,
)
if config.tensor_parallelism_size > 1:
dataset.root_theta = shard_theta(dataset.root_theta, config)
affinities = [DeviceAffinity(i) for i in range(args.tensor_parallelism_size)]

if config.hp.expert_count:
if config.hp.model_arch == "grok":
model = PagedGrokModelV1(dataset.root_theta, config)
else:
model = PagedMixtralModelV1(dataset.root_theta, config)
else:
model = PagedLlamaModelV1(dataset.root_theta, config)
model = PagedLlamaModelV1(dataset.root_theta, config, devices=affinities)

if args.save_intermediates_path:
from ..utils.patching import SaveModuleResultTensorsPatch
4 changes: 3 additions & 1 deletion sharktank/sharktank/layers/kv_cache.py
@@ -292,7 +292,9 @@ def unflatten_page_table(
shards = [
shard.unflatten(1, self.sub_page_dims) for shard in page_slab.shards
]
return SplitPrimitiveTensor(ts=shards, shard_dim=4)
return SplitPrimitiveTensor(
ts=shards, shard_dim=4, devices=page_slab.devices
)

def shard_state(
self, state: List[torch.Tensor]
19 changes: 12 additions & 7 deletions sharktank/sharktank/layers/rotary_embedding.py
@@ -27,18 +27,19 @@ def __init__(
use_hf: bool = False,
static_tables: bool = False,
use_table: bool = True,
tensor_parallelism_size: int = 1,
devices: list | None = None,
):
super().__init__()
self.device = device
self.devices = devices
self.rope_dimension_count = rope_dimension_count
self.max_seqlen = max_seqlen
self.use_hf = use_hf
self.static_tables = static_tables
self.use_table = use_table

self.rope_freq_base = rope_freq_base if rope_freq_base is not None else 10000.0
self.tensor_parallelism_size = tensor_parallelism_size
self.devices = devices
if static_tables:
ops.module_register_buffer(
self, "static_rotary_embed_table", self._create_rotary_embed_table()
@@ -80,7 +81,9 @@ def forward(
)
for xt_shard, rotary_shard in zip(xt.shards, rotary_shards)
]
xt = SplitPrimitiveTensor(ts=xt_shards, shard_dim=xt.shard_dim)
xt = SplitPrimitiveTensor(
ts=xt_shards, shard_dim=xt.shard_dim, devices=xt.devices
)
return xt
else:
return self.forward_unsharded(
@@ -189,7 +192,7 @@ def compute_batch_mask(
self._compute_rotary_embed_table(s.flatten()).unflatten(0, shape)
for s in positions_seq.shards
]
freqs_cis = ReplicatedTensor(ts=ts)
freqs_cis = ReplicatedTensor(ts=ts, devices=positions_seq.devices)
else:
freqs_cis = self._compute_rotary_embed_table(positions_seq.flatten())
freqs_cis = freqs_cis.unflatten(0, shape)
@@ -215,7 +218,9 @@ def apply_batched_mask(
)
for xt_shard, mask_shard in zip(xt.shards, mask.shards)
]
xt = SplitPrimitiveTensor(ts=xt_shards, shard_dim=xt.shard_dim)
xt = SplitPrimitiveTensor(
ts=xt_shards, shard_dim=xt.shard_dim, devices=xt.devices
)
return xt

def apply_batched_mask_unsharded(self, *, xt: torch.Tensor, mask: torch.Tensor):
@@ -259,8 +264,8 @@ def _create_rotary_embed_table(self):
return self._replicate(freqs_cis)

def _replicate(self, t):
if self.tensor_parallelism_size > 1:
if self.devices is not None and len(self.devices) > 1:
# Replicate across all devices, the data is not a lot and the computation is cheap.
t = ops.replicate(t, self.tensor_parallelism_size)
t = ops.replicate(t, self.devices)

return t
12 changes: 8 additions & 4 deletions sharktank/sharktank/models/llama/llama.py
@@ -62,7 +62,7 @@ class PagedLlamaModelV1(BaseCausalLMModel):
unsharded result or chain it with other tensor-parallel operations.
"""

def __init__(self, theta: Theta, config: LlamaModelConfig):
def __init__(self, theta: Theta, config: LlamaModelConfig, devices: list):
hp = config.hp
super().__init__(
theta,
@@ -91,9 +91,9 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
rope_freq_base=hp.rope_freq_base,
max_seqlen=hp.context_length,
device=self.device,
devices=devices,
use_hf=self.use_hf,
static_tables=config.static_tables,
tensor_parallelism_size=config.tensor_parallelism_size,
),
)
self.add_module(
@@ -238,8 +238,12 @@ def decode(
)
for _ in range(self.config.tensor_parallelism_size)
]
xk_temp = SplitPrimitiveTensor(ts=xk_temp_shard, shard_dim=2)
xv_temp = SplitPrimitiveTensor(ts=xv_temp_shard, shard_dim=2)
xk_temp = SplitPrimitiveTensor(
ts=xk_temp_shard, shard_dim=2, devices=tokens.devices
)
xv_temp = SplitPrimitiveTensor(
ts=xv_temp_shard, shard_dim=2, devices=tokens.devices
)

h = self.token_embedding(tokens)
self.trace_tensor("llama.token_embedding", h)