Expanded sharded support for alternative sharding mechanisms #680

Open · wants to merge 3 commits into base: main
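Note (not part of the diff): taken together, the changes below move the sharded code paths from count-based replication to explicit device-affinity lists. The following is a minimal sketch of the new calling convention, using only the signatures that appear in this PR (ops.replicate(..., devices=...), SplitPrimitiveTensor(..., devices=...), DeviceAffinity); it assumes the PR branch is installed and is illustrative rather than verified.

import torch
from iree.turbine.aot import DeviceAffinity

from sharktank import ops
from sharktank.types import SplitPrimitiveTensor

tensor_parallelism_size = 2
# One affinity per tensor-parallel shard; the same call sites previously
# passed count=tensor_parallelism_size instead.
affinities = [DeviceAffinity(i) for i in range(tensor_parallelism_size)]

tokens = torch.zeros(4, 16, dtype=torch.int64)
# Replication is now driven by the affinity list rather than a shard count.
replicated = ops.replicate(tokens, devices=affinities)

# Split tensors likewise carry their placement explicitly.
shards = list(torch.zeros(4, 16).chunk(tensor_parallelism_size, dim=1))
split = SplitPrimitiveTensor(ts=shards, shard_dim=1, devices=affinities)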
82 changes: 55 additions & 27 deletions sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -18,11 +18,12 @@

# TODO: Should be using a base class with the protocol supported.
from ..models.llama.llama import LlamaModelConfig, PagedLlamaModelV1
from ..models.llama.sharding import shard_theta
from ..models.mixtral.mixtral import *
from ..models.grok.grok import *
from .. import ops

from typing import Union, Sequence


def main():
from ..utils import cli
@@ -55,6 +56,11 @@ def main():
help="Enables strictness during export",
action="store_true",
)
parser.add_argument(
"--use-queue-affinities",
help="Enables queue affinities for multidevice",
action="store_true",
)

cli.add_quantization_options(parser)
cli.add_model_options(parser)
@@ -74,18 +80,44 @@ def main():
tensor_parallelism_size=tensor_parallelism_size,
use_hf=False,
static_tables=False, # Rely on the compiler for hoisting tables.
kv_cache_type="direct" if args.bs == [1] else "paged",
kv_cache_type="paged" if args.bs == [1] else "paged",
attention_kernel=args.attention_kernel,
)
llama_config.fake_quant = args.fake_quant

def setup_queue_affinities(sharding):
return [DeviceAffinity(0, [str(i)]) for i in range(sharding)]

def assign_affinities(theta, affinities):
def transform(tensors: Union[None, InferenceTensor, Sequence[InferenceTensor]]):
if tensors is None:
return tensors
if isinstance(tensors, Sequence):
return [transform(t) for t in tensors]
if isinstance(tensors, ShardedTensor):
tensors.assign_affinities(affinities)
return tensors
if isinstance(tensors, InferenceTensor):
return tensors

raise ValueError("Unknown device for reassigned affinities")

return theta.transform(transform)

affinities = [
DeviceAffinity(i) for i in range(llama_config.tensor_parallelism_size)
]
if args.use_queue_affinities:
affinities = setup_queue_affinities(llama_config.tensor_parallelism_size)
dataset.root_theta = assign_affinities(dataset.root_theta, affinities)

if llama_config.hp.expert_count:
if llama_config.hp.model_arch == "grok":
model = PagedGrokModelV1(dataset.root_theta, llama_config)
else:
model = PagedMixtralModelV1(dataset.root_theta, llama_config)
else:
model = PagedLlamaModelV1(dataset.root_theta, llama_config)
model = PagedLlamaModelV1(dataset.root_theta, llama_config, affinities)

def generate_params_json(
hp: LlamaHParams, prefill_bs: list[int], decode_bs: list[int]
@@ -121,7 +153,7 @@ def generate_params_json(

fxb = FxProgramsBuilder(model)

def setup_cache(model, shard_count):
def setup_cache(model, affinities):
if model.config.kv_cache_type == "paged":
cache_state = model.cache.allocate(
page_count=hp.context_length // llama_config.block_seq_stride
@@ -143,24 +175,21 @@ def setup_cache(model, shard_count):
]

for i in range(llama_config.tensor_parallelism_size):
arg_affinities[i] = DeviceAffinity(str(i))
arg_affinities[i] = affinities[i]

return unpacked, shard_dim, dynamic_shapes, arg_affinities

elif model.config.kv_cache_type == "direct":
cache_state = model.cache.allocate(bs=1)
# Direct cache dimensions:
# 2 * transformer_block_count of...
# [bs, seq_length, attn_head_count, attn_head_dim]
dynamic_shapes = [None]
arg_affinities = {}
shard_dim = None
return torch.stack(cache_state), shard_dim, dynamic_shapes, arg_affinities
raise NotImplementedError(f"Direct cache is not currently functional")

else:
raise NotImplementedError(f"Unsupported KV cache type: {type(model.cache)}")

def repack_cache(cache, shard_dim):
return [SplitPrimitiveTensor(ts=c, shard_dim=shard_dim) for c in cache]
return [
SplitPrimitiveTensor(ts=c, shard_dim=shard_dim, devices=affinities)
for c in cache
]

def generate_batch_prefill(bs: int):
# torch.export.Dim would make min at least 2
@@ -177,7 +206,7 @@ def generate_batch_prefill(bs: int):
seq_lens = torch.empty(bs, dtype=torch.int64)

cache, cache_shard_dim, cache_dynamic_shapes, arg_affinities = setup_cache(
model, llama_config.tensor_parallelism_size
model, affinities
)

if llama_config.tensor_parallelism_size > 1:
@@ -219,9 +248,9 @@ def _(model, tokens, seq_lens, seq_block_ids, cs):
if llama_config.tensor_parallelism_size != 1:
shard_count = llama_config.tensor_parallelism_size

tokens = ops.replicate(tokens, count=shard_count)
attention_mask = ops.replicate(attention_mask, count=shard_count)
seq_block_ids = ops.replicate(seq_block_ids, count=shard_count)
tokens = ops.replicate(tokens, devices=affinities)
attention_mask = ops.replicate(attention_mask, devices=affinities)
seq_block_ids = ops.replicate(seq_block_ids, devices=affinities)

cache_tensors = repack_cache(cs, cache_shard_dim)

@@ -256,15 +285,15 @@ def generate_batch_decode(bs: int):
cache_shard_dim,
cache_dynamic_shapes,
arg_affinities,
) = setup_cache(model, llama_config.tensor_parallelism_size)
) = setup_cache(model, affinities)

if llama_config.tensor_parallelism_size > 1:
# We need to offset the indices for the cache
arg_affinities = {key + 4: arg_affinities[key] for key in arg_affinities}

# Inputs have default affinity 0
for i in range(4):
arg_affinities[i] = DeviceAffinity("0")
arg_affinities[i] = affinities[0]

dynamic_shapes = {
"tokens": {},
@@ -303,12 +332,10 @@ def _(
attention_mask = model.decode_attention_mask(input_mask)

if llama_config.tensor_parallelism_size != 1:
shard_count = llama_config.tensor_parallelism_size

tokens = ops.replicate(tokens, count=shard_count)
attention_mask = ops.replicate(attention_mask, count=shard_count)
start_positions = ops.replicate(start_positions, count=shard_count)
seq_block_ids = ops.replicate(seq_block_ids, count=shard_count)
tokens = ops.replicate(tokens, devices=affinities)
attention_mask = ops.replicate(attention_mask, devices=affinities)
start_positions = ops.replicate(start_positions, devices=affinities)
seq_block_ids = ops.replicate(seq_block_ids, devices=affinities)

cache_state = repack_cache(cache_state, cache_shard_dim)

@@ -327,7 +354,8 @@ def _(

bsizes = []
for bs in args.bs:
generate_batch_prefill(bs)
if not args.skip_prefill:
generate_batch_prefill(bs)
if not args.skip_decode:
generate_batch_decode(bs)
bsizes.append(bs)
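Note (not part of the diff): the new --use-queue-affinities flag changes how the affinities are laid out. The default path places shard i on device ordinal i, while setup_queue_affinities pins every shard to device 0 with a per-shard queue string; the two-argument DeviceAffinity form is taken from the hunk above and its queue semantics are assumed, not verified here. A short sketch of the two layouts:

from iree.turbine.aot import DeviceAffinity

tensor_parallelism_size = 4

# Default layout: shard i is placed on device ordinal i.
device_affinities = [DeviceAffinity(i) for i in range(tensor_parallelism_size)]

# --use-queue-affinities layout: every shard on device 0, one queue per shard,
# mirroring setup_queue_affinities() in the hunk above.
queue_affinities = [
    DeviceAffinity(0, [str(i)]) for i in range(tensor_parallelism_size)
]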
23 changes: 13 additions & 10 deletions sharktank/sharktank/examples/paged_llm_v1.py
@@ -15,6 +15,8 @@

import torch

from iree.turbine.aot import DeviceAffinity

from ..layers import *
from ..types import *

@@ -156,9 +158,10 @@ def prefill(self):
token_ids = self.token_ids
if model.config.tensor_parallelism_size != 1:
tp = model.config.tensor_parallelism_size
token_ids = replicate(token_ids, tp)
attention_mask = replicate(attention_mask, tp)
seq_block_ids_tensor = replicate(seq_block_ids_tensor, tp)
devices = [DeviceAffinity(i) for i in range(tp)]
token_ids = replicate(token_ids, devices)
attention_mask = replicate(attention_mask, devices)
seq_block_ids_tensor = replicate(seq_block_ids_tensor, devices)

logits = model.prefill(
token_ids,
@@ -199,10 +202,11 @@ def decode(self):

if model.config.tensor_parallelism_size != 1:
tp = model.config.tensor_parallelism_size
self.next_tokens = replicate(self.next_tokens, tp)
start_positions = replicate(start_positions, tp)
seq_block_ids_tensor = replicate(seq_block_ids_tensor, tp)
decode_attention_mask = replicate(decode_attention_mask, tp)
devices = [DeviceAffinity(i) for i in range(tp)]
self.next_tokens = replicate(self.next_tokens, devices)
start_positions = replicate(start_positions, devices)
seq_block_ids_tensor = replicate(seq_block_ids_tensor, devices)
decode_attention_mask = replicate(decode_attention_mask, devices)

logits = model.decode(
self.next_tokens,
@@ -279,16 +283,15 @@ def main():
tensor_parallelism_size=args.tensor_parallelism_size,
fake_quant=args.fake_quant,
)
if config.tensor_parallelism_size > 1:
dataset.root_theta = shard_theta(dataset.root_theta, config)
Collaborator (review comment): Can remove shard_theta import if unused.

affinities = [DeviceAffinity(i) for i in range(args.tensor_parallelism_size)]

if config.hp.expert_count:
if config.hp.model_arch == "grok":
model = PagedGrokModelV1(dataset.root_theta, config)
else:
model = PagedMixtralModelV1(dataset.root_theta, config)
else:
model = PagedLlamaModelV1(dataset.root_theta, config)
model = PagedLlamaModelV1(dataset.root_theta, config, devices=affinities)

if args.save_intermediates_path:
from ..utils.patching import SaveModuleResultTensorsPatch
6 changes: 5 additions & 1 deletion sharktank/sharktank/examples/sharding/export_ffn_net.py
@@ -17,6 +17,8 @@
import torch
import torch.nn as nn

from iree.turbine.aot import DeviceAffinity

from ...layers import *
from ... import ops
from ...types import *
@@ -50,7 +52,9 @@ def forward(self, x: torch.Tensor):
ffn_gate_weight = self.theta.tensor("ffn_gate", "weight")
ffn_up_weight = self.theta.tensor("ffn_up", "weight")
ffn_down_weight = self.theta.tensor("ffn_down", "weight")
x = ops.replicate(x, count=ffn_gate_weight.shard_count)

devices = [DeviceAffinity(i) for i in range(ffn_down_weight.shard_count)]
x = ops.replicate(x, devices=devices)
ffn_gate = ops.elementwise(
torch.nn.functional.silu, ops.linear(x, ffn_gate_weight)
)
4 changes: 3 additions & 1 deletion sharktank/sharktank/layers/kv_cache.py
@@ -292,7 +292,9 @@ def unflatten_page_table(
shards = [
shard.unflatten(1, self.sub_page_dims) for shard in page_slab.shards
]
return SplitPrimitiveTensor(ts=shards, shard_dim=4)
return SplitPrimitiveTensor(
ts=shards, shard_dim=4, devices=page_slab.devices
)

def shard_state(
self, state: List[torch.Tensor]
Expand Down
19 changes: 12 additions & 7 deletions sharktank/sharktank/layers/rotary_embedding.py
@@ -27,18 +27,19 @@ def __init__(
use_hf: bool = False,
static_tables: bool = False,
use_table: bool = True,
tensor_parallelism_size: int = 1,
devices: list | None = None,
):
super().__init__()
self.device = device
self.devices = devices
self.rope_dimension_count = rope_dimension_count
self.max_seqlen = max_seqlen
self.use_hf = use_hf
self.static_tables = static_tables
self.use_table = use_table

self.rope_freq_base = rope_freq_base if rope_freq_base is not None else 10000.0
self.tensor_parallelism_size = tensor_parallelism_size
self.devices = devices
Collaborator (review comment): This is redundant with L34, can be removed.

if static_tables:
ops.module_register_buffer(
self, "static_rotary_embed_table", self._create_rotary_embed_table()
@@ -80,7 +81,9 @@ def forward(
)
for xt_shard, rotary_shard in zip(xt.shards, rotary_shards)
]
xt = SplitPrimitiveTensor(ts=xt_shards, shard_dim=xt.shard_dim)
xt = SplitPrimitiveTensor(
ts=xt_shards, shard_dim=xt.shard_dim, devices=xt.devices
)
return xt
else:
return self.forward_unsharded(
@@ -189,7 +192,7 @@ def compute_batch_mask(
self._compute_rotary_embed_table(s.flatten()).unflatten(0, shape)
for s in positions_seq.shards
]
freqs_cis = ReplicatedTensor(ts=ts)
freqs_cis = ReplicatedTensor(ts=ts, devices=positions_seq.devices)
else:
freqs_cis = self._compute_rotary_embed_table(positions_seq.flatten())
freqs_cis = freqs_cis.unflatten(0, shape)
@@ -215,7 +218,9 @@ def apply_batched_mask(
)
for xt_shard, mask_shard in zip(xt.shards, mask.shards)
]
xt = SplitPrimitiveTensor(ts=xt_shards, shard_dim=xt.shard_dim)
xt = SplitPrimitiveTensor(
ts=xt_shards, shard_dim=xt.shard_dim, devices=xt.devices
)
return xt

def apply_batched_mask_unsharded(self, *, xt: torch.Tensor, mask: torch.Tensor):
@@ -259,8 +264,8 @@ def _create_rotary_embed_table(self):
return self._replicate(freqs_cis)

def _replicate(self, t):
if self.tensor_parallelism_size > 1:
if self.devices is not None and len(self.devices) > 1:
# Replicate across all devices, the data is not a lot and the computation is cheap.
t = ops.replicate(t, self.tensor_parallelism_size)
t = ops.replicate(t, self.devices)

return t
17 changes: 13 additions & 4 deletions sharktank/sharktank/models/llama/llama.py
@@ -9,6 +9,8 @@
from dataclasses import dataclass
from typing import Union

from iree.turbine.aot import DeviceAffinity

import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -62,7 +64,7 @@ class PagedLlamaModelV1(BaseCausalLMModel):
unsharded result or chain it with other tensor-parallel operations.
"""

def __init__(self, theta: Theta, config: LlamaModelConfig):
def __init__(self, theta: Theta, config: LlamaModelConfig, devices: list = None):
hp = config.hp
super().__init__(
theta,
@@ -80,6 +82,9 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
self.use_hf = config.use_hf
self.attention_kernel = config.attention_kernel

if devices is None:
devices = [DeviceAffinity(i) for i in range(config.tensor_parallelism_size)]

self.add_module(
"token_embedding",
TokenEmbeddingLayer(theta("token_embd"), dtype=config.activation_dtype),
@@ -91,9 +96,9 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
rope_freq_base=hp.rope_freq_base,
max_seqlen=hp.context_length,
device=self.device,
devices=devices,
use_hf=self.use_hf,
static_tables=config.static_tables,
tensor_parallelism_size=config.tensor_parallelism_size,
),
)
self.add_module(
@@ -238,8 +243,12 @@ def decode(
)
for _ in range(self.config.tensor_parallelism_size)
]
xk_temp = SplitPrimitiveTensor(ts=xk_temp_shard, shard_dim=2)
xv_temp = SplitPrimitiveTensor(ts=xv_temp_shard, shard_dim=2)
xk_temp = SplitPrimitiveTensor(
ts=xk_temp_shard, shard_dim=2, devices=tokens.devices
)
xv_temp = SplitPrimitiveTensor(
ts=xv_temp_shard, shard_dim=2, devices=tokens.devices
)

h = self.token_embedding(tokens)
self.trace_tensor("llama.token_embedding", h)
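Note (not part of the diff): with this change PagedLlamaModelV1 accepts an optional devices argument and, when it is omitted, builds one DeviceAffinity per tensor-parallel rank itself. A minimal construction sketch under the assumption that dataset and llama_config are prepared as in export_paged_llm_v1.py above; it mirrors the call sites in the two example scripts rather than adding new API.

from iree.turbine.aot import DeviceAffinity

from sharktank.models.llama.llama import LlamaModelConfig, PagedLlamaModelV1


def build_model(dataset, llama_config: LlamaModelConfig) -> PagedLlamaModelV1:
    # `dataset` is the sharktank Dataset loaded by the export script; only
    # .root_theta is used here.
    affinities = [
        DeviceAffinity(i) for i in range(llama_config.tensor_parallelism_size)
    ]
    # Passing the affinities explicitly matches export_paged_llm_v1.py and
    # paged_llm_v1.py; omitting `devices` yields the same per-rank placement
    # via the constructor's default.
    return PagedLlamaModelV1(dataset.root_theta, llama_config, devices=affinities)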