Commit

fix tests

rsuderman committed Dec 12, 2024
1 parent afbb8e6 commit e486ad4
Showing 11 changed files with 83 additions and 37 deletions.
6 changes: 5 additions & 1 deletion sharktank/sharktank/examples/sharding/export_ffn_net.py
@@ -17,6 +17,8 @@
import torch
import torch.nn as nn

from iree.turbine.aot import DeviceAffinity

from ...layers import *
from ... import ops
from ...types import *
@@ -50,7 +52,9 @@ def forward(self, x: torch.Tensor):
ffn_gate_weight = self.theta.tensor("ffn_gate", "weight")
ffn_up_weight = self.theta.tensor("ffn_up", "weight")
ffn_down_weight = self.theta.tensor("ffn_down", "weight")
x = ops.replicate(x, count=ffn_gate_weight.shard_count)

devices = [DeviceAffinity(i) for i in range(ffn_down_weight.shard_count)]
x = ops.replicate(x, devices=devices)
ffn_gate = ops.elementwise(
torch.nn.functional.silu, ops.linear(x, ffn_gate_weight)
)
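
The hunk above swaps the count-based replicate call for an explicit device list. A minimal sketch of the new calling pattern, with the tensor shape and shard count chosen purely for illustration:

    import torch
    from iree.turbine.aot import DeviceAffinity
    from sharktank import ops

    # Illustrative input; any tensor to be replicated follows the same pattern.
    x = torch.rand(4, 64, dtype=torch.float32)
    shard_count = 2

    # Build one DeviceAffinity per shard, then replicate onto those devices.
    devices = [DeviceAffinity(i) for i in range(shard_count)]
    x_replicated = ops.replicate(x, devices=devices)
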
7 changes: 6 additions & 1 deletion sharktank/sharktank/models/llama/llama.py
@@ -9,6 +9,8 @@
from dataclasses import dataclass
from typing import Union

from iree.turbine.aot import DeviceAffinity

import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -62,7 +64,7 @@ class PagedLlamaModelV1(BaseCausalLMModel):
unsharded result or chain it with other tensor-parallel operations.
"""

def __init__(self, theta: Theta, config: LlamaModelConfig, devices: list):
def __init__(self, theta: Theta, config: LlamaModelConfig, devices: list = None):
hp = config.hp
super().__init__(
theta,
@@ -80,6 +82,9 @@ def __init__(self, theta: Theta, config: LlamaModelConfig, devices: list):
self.use_hf = config.use_hf
self.attention_kernel = config.attention_kernel

if devices is None:
devices = [DeviceAffinity(i) for i in range(config.tensor_parallelism_size)]

self.add_module(
"token_embedding",
TokenEmbeddingLayer(theta("token_embd"), dtype=config.activation_dtype),
42 changes: 33 additions & 9 deletions sharktank/sharktank/ops/sharded_impls.py
@@ -923,38 +923,62 @@ def repeat_replicated(input: ReplicatedTensor, *sizes: List[int]) -> ReplicatedTensor:


@replicate.override(ReplicatedTensor)
def replicate_replicated(input: ReplicatedTensor, *, devices: list) -> ReplicatedTensor:
if input.shard_count != len(devices):
def replicate_replicated(
input: ReplicatedTensor, *, devices: list, count: int
) -> ReplicatedTensor:
if devices is not None and input.shard_count != len(devices):
raise ValueError(
f"Number of shards not equal ({input.shard_count} != {len(devices)})"
)
if count is not None and input.shard_count != count:
raise ValueError(f"Number of shards not equal ({input.shard_count} != {count})")
return input


@replicate.override(SplitPrimitiveTensor)
def replicate_split(input: SplitPrimitiveTensor, *, devices: list) -> ReplicatedTensor:
if input.shard_count != len(devices):
def replicate_split(
input: SplitPrimitiveTensor, *, devices: list, count: int
) -> ReplicatedTensor:
if devices is not None and input.shard_count != len(devices):
raise ValueError(
f"Number of shards not equal ({input.shard_count} != {len(devices)})"
)
if count is not None and input.shard_count != count:
raise ValueError(f"Number of shards not equal ({input.shard_count} != {count})")
return all_gather(input)


@replicate.override(UnreducedTensor)
def replicate_unreduced(input: UnreducedTensor, *, devices: list) -> ReplicatedTensor:
if input.shard_count != len(devices):
def replicate_unreduced(
input: UnreducedTensor, *, devices: list, count: int
) -> ReplicatedTensor:
if devices is not None and input.shard_count != len(devices):
raise ValueError(
f"Number of shards not equal ({input.shard_count} != {len(devices)})"
)
if count is not None and input.shard_count != count:
raise ValueError(f"Number of shards not equal ({input.shard_count} != {count})")
return all_reduce(input)


@replicate.override(Tensor)
def replicate_unsharded(input, *, devices: list) -> ReplicatedTensor:
def replicate_unsharded(input, *, devices: list, count: int) -> ReplicatedTensor:
torch_input = unbox_tensor(input)
# If we have a torch input replicating we can assume we need to transfer:
torch_inputs = [transfer_to_logical_device(torch_input, d.ordinal) for d in devices]
return ReplicatedTensor(ts=torch_inputs, devices=devices)
if devices is not None:
torch_inputs = [
transfer_to_logical_device(torch_input, d.ordinal) for d in devices
]
return ReplicatedTensor(ts=torch_inputs, devices=devices)

if count is not None:
devices = [DeviceAffinity(i) for i in range(count)]
torch_inputs = [
transfer_to_logical_device(torch_input, i) for i in range(count)
]
return ReplicatedTensor(ts=torch_inputs, devices=devices)

raise ValueError(f"Devices or count is required")


@reshape.override(SplitPrimitiveTensor)
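
The reworked overrides accept either devices or count and validate whichever is supplied against the input's shard count. A rough usage sketch against a split tensor; the shapes and the 2-way split are assumptions for illustration, not taken from the commit:

    import torch
    from iree.turbine.aot import DeviceAffinity
    from sharktank import ops
    from sharktank.types import SplitPrimitiveTensor

    # Illustrative tensor split into 2 shards along dim 1.
    t = torch.rand(4, 6, dtype=torch.float32)
    split = SplitPrimitiveTensor(ts=t, shard_dim=1, shard_count=2)

    # Both forms dispatch to replicate_split and all-gather the shards into a
    # ReplicatedTensor; a mismatched count or device list raises ValueError.
    replicated_a = ops.replicate(split, count=2)
    replicated_b = ops.replicate(split, devices=[DeviceAffinity(i) for i in range(2)])
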
8 changes: 5 additions & 3 deletions sharktank/sharktank/ops/signatures.py
@@ -796,7 +796,9 @@ def _repeat_trampoline(


@overridable
def replicate(input: AnyTensor, devices: list) -> ShardedTensor:
def replicate(
input: AnyTensor, devices: list = None, count: int = None
) -> ShardedTensor:
"""Replicate across devices.
Possibly reshards if required."""
@@ -805,11 +807,11 @@ def replicate(input: AnyTensor, devices: list) -> ShardedTensor:

@replicate.trampoline
def _replicate_trampoline(
d: SignatureDispatcher, input: AnyTensor, devices: list
d: SignatureDispatcher, input: AnyTensor, devices: list = None, count: int = None
) -> ShardedTensor:
tensors = (input,)
for override in d.find_overrides(tensors):
result = override(input, devices=devices)
result = override(input, devices=devices, count=count)
if result is not NotImplemented:
return override, result
else:
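
With the widened signature, the trampoline forwards both keywords to every override. A short sketch of the two entry points for an unsharded input; the values below are assumptions for illustration only:

    import torch
    from iree.turbine.aot import DeviceAffinity
    from sharktank import ops

    tensor = torch.rand(4, 5, dtype=torch.float32)

    # count=N builds DeviceAffinity(0..N-1) internally (see replicate_unsharded above).
    by_count = ops.replicate(tensor, count=3)

    # devices= makes the placement explicit and determines the resulting shard count.
    by_devices = ops.replicate(tensor, devices=[DeviceAffinity(i) for i in range(3)])
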
12 changes: 5 additions & 7 deletions sharktank/sharktank/types/tensors.py
@@ -813,9 +813,6 @@ def __init__(

self._devices: tuple[DeviceAffinity] = tuple(devices)

for i, t in enumerate(ts):
DeviceTensorTrait(i).set(t)

def assign_affinities(self, affinities):
assert len(affinities) == len(self._devices)
self._devices = tuple(affinities)
@@ -893,6 +890,7 @@ def create(
t_name = str(i)
try:
t = raw_tensors[t_name]
DeviceTensorTrait(i).set(t)
ts.append(t)
except KeyError as e:
raise IOError(
@@ -996,7 +994,7 @@ def __init__(
number of pieces.
"""

assert shard_count is None or not isinstance(ts, torch.Tensor)
assert shard_count is None or isinstance(ts, torch.Tensor)
shard_count = shard_count if shard_count is not None else len(ts)

if devices is None:
@@ -1169,9 +1167,6 @@ def __init__(

self._devices: tuple[DeviceAffinity] = tuple(devices)

for d, t in zip(devices, ts):
DeviceTensorTrait(d.ordinal, d.queues).set(t)

def assign_affinities(self, affinities):
assert len(affinities) == len(self._devices)
self._devices = tuple(affinities)
@@ -1238,6 +1233,9 @@ def create(
nt = deepcopy(t)
ts.append(nt)

for i, t in enumerate(ts):
DeviceTensorTrait(i).set(t)

except KeyError as e:
raise IOError(f"Missing component tensor '' in {raw_tensors.keys()}") from e
return cls(name=name, ts=ts)
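
These constructor changes pair with the test updates below: replicated tensors are now built from an explicit list of shards, and device traits are applied when tensors are loaded through create(). A minimal sketch of the new construction form, with shape and shard count chosen for illustration:

    import torch
    from sharktank.types import ReplicatedTensor

    tensor = torch.rand(2, 3, 4, dtype=torch.float32)

    # One list entry per shard; device affinities default when none are passed.
    replicated = ReplicatedTensor(ts=[tensor] * 3)
    assert replicated.shard_count == 3
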
11 changes: 8 additions & 3 deletions sharktank/tests/layers/kv_cache_test.py
@@ -8,6 +8,7 @@

import torch

from iree.turbine.aot import DeviceAffinity
from sharktank.ops import replicate, reshard_split, unshard
from sharktank.layers import *
from sharktank.types import *
@@ -148,6 +149,8 @@ def test_sharded_direct():

write_seq_length = seq_length - 5

devices = [DeviceAffinity(i) for i in range(shard_count)]

# Write a prefill in:
write_ones = reshard_split(
torch.full(
@@ -204,7 +207,7 @@ def test_sharded_direct():
)

write_pos = replicate(
torch.full((bs,), write_seq_length, dtype=torch.int64), shard_count
torch.full((bs,), write_seq_length, dtype=torch.int64), devices
)
cache.write_timestep(
allocation,
@@ -379,11 +382,13 @@ def test_sharded_paged():
device=None,
)

devices = [DeviceAffinity(i) for i in range(shard_count)]

write_seq_length = seq_length - 4
page_count = bs * seq_length // block_seq_stride
page_ids = torch.arange(page_count, dtype=torch.int64)
page_ids = page_ids.view(bs, seq_length // block_seq_stride)
page_ids = replicate(page_ids, shard_count)
page_ids = replicate(page_ids, devices=devices)
write_page_ids = page_ids[:, : write_seq_length // block_seq_stride]

allocation = cache.allocate(page_count=page_count)
@@ -458,7 +463,7 @@ def test_sharded_paged():
)

write_pos = replicate(
torch.full((bs,), write_seq_length, dtype=torch.int64), shard_count
torch.full((bs,), write_seq_length, dtype=torch.int64), devices
)

cache.write_timestep(
16 changes: 10 additions & 6 deletions sharktank/tests/layers/sharded_paged_kv_cache_test.py
@@ -5,13 +5,16 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import unittest
from sharktank.layers import PagedKVCache
import torch
from sharktank.utils import iterables_equal
from copy import deepcopy
from typing import List, Tuple

from iree.turbine.aot import DeviceAffinity

from sharktank import ops
from sharktank.layers import PagedKVCache
from sharktank.types import SplitPrimitiveTensor
from sharktank.utils import iterables_equal


class ShardedPagedKVCacheTest(unittest.TestCase):
@@ -31,6 +34,7 @@ def setUp(self):
self.batch_size = 11
self.block_seq_len = 2
self.max_seq_len = self.block_seq_len * self.block_seq_stride
self.devices = [DeviceAffinity(i) for i in range(self.shard_count)]

self.cache = PagedKVCache(
transformer_block_count=self.transformer_block_count,
@@ -131,7 +135,7 @@ def testRead(self):
for t in read_into_partitions_snapshot
]
)
sharded_page_ids = ops.replicate(page_ids, count=self.shard_count)
sharded_page_ids = ops.replicate(page_ids, devices=self.devices)
self.sharded_cache.read(
state=sharded_cache_state,
read_into_partitions=sharded_read_into_partitions,
@@ -179,8 +183,8 @@ def testWriteTimestep(self):
for t in cache_partitions
]
)
sharded_seq_positions = ops.replicate(seq_positions, count=self.shard_count)
sharded_page_ids = ops.replicate(page_ids, count=self.shard_count)
sharded_seq_positions = ops.replicate(seq_positions, devices=self.devices)
sharded_page_ids = ops.replicate(page_ids, devices=self.devices)
self.sharded_cache.write_timestep(
state=sharded_cache_state,
cache_partitions=sharded_cache_partitions,
@@ -224,7 +228,7 @@ def testWrite(self):
for t in cache_partitions
]
)
sharded_page_ids = ops.replicate(page_ids, count=self.shard_count)
sharded_page_ids = ops.replicate(page_ids, devices=self.devices)
self.sharded_cache.write(
state=sharded_cache_state,
cache_partitions=sharded_cache_partitions,
6 changes: 5 additions & 1 deletion sharktank/tests/layers/sharded_rotary_embedding_test.py
@@ -7,6 +7,8 @@

import torch

from iree.turbine.aot import DeviceAffinity

from sharktank.layers import RotaryEmbeddingLayer
from sharktank import ops
from sharktank.types import (
@@ -27,6 +29,8 @@ def test_sharded_rotary_table():
max_seqlen = 128
rope_freq_base = None

devices = [DeviceAffinity(i) for i in range(4)]

# First we setup and get the default rotary embedding layer
xq = torch.rand((bs, max_seqlen, heads, rope_dims), dtype=torch.float)
xk = torch.rand((bs, max_seqlen, heads, rope_dims), dtype=torch.float)
@@ -45,7 +49,7 @@ def test_sharded_rotary_table():
rope_dimension_count=rope_dims,
max_seqlen=max_seqlen,
rope_freq_base=rope_freq_base,
tensor_parallelism_size=4,
devices=devices,
)
sq = shard_layer(xt=xq, start_index=0)
sk = shard_layer(xt=xk, start_index=0)
2 changes: 1 addition & 1 deletion sharktank/tests/models/llama/sharded_llama_test.py
@@ -6,7 +6,7 @@

import unittest
import pytest
from typing import Any, List, Tuple, OrderedDict
from typing import Any, Tuple, OrderedDict
from sharktank.models.llama.llama import LlamaModelConfig, PagedLlamaModelV1
import sharktank.ops as ops
from sharktank.types import unbox_tensor, Dataset, UnreducedTensor, SplitPrimitiveTensor
6 changes: 3 additions & 3 deletions sharktank/tests/ops/sharded_test.py
@@ -630,7 +630,7 @@ def testAttentionShardedBatchMask(self):
q_s = SplitPrimitiveTensor(shard_dim=0, ts=q.split(1, dim=0))
k_s = SplitPrimitiveTensor(shard_dim=0, ts=k.split(1, dim=0))
v_s = SplitPrimitiveTensor(shard_dim=0, ts=v.split(1, dim=0))
a_s = ReplicatedTensor(ts=a, shard_count=4)
a_s = ReplicatedTensor(ts=[a] * 4)

expected_result = ops.scaled_dot_product_attention(
q, k, v, a=a, is_causal=False
@@ -754,7 +754,7 @@ def testShardedLhsReplcatedRhs(self):
expected_result = torch.matmul(a, b)
shard_count = 3
a_sharded = SplitPrimitiveTensor(ts=a, shard_dim=1, shard_count=shard_count)
b_sharded = ReplicatedTensor(ts=b, shard_count=shard_count)
b_sharded = ReplicatedTensor(ts=[b] * shard_count)
res_sharded = ops.matmul(a_sharded, b_sharded)
assert isinstance(res_sharded, SplitPrimitiveTensor)
assert res_sharded.shard_dim == 1
@@ -837,7 +837,7 @@ def testReplicateUnsharded(self):
tensor = torch.rand(4, 5, dtype=torch.float32)
shard_count = 3
actual_result = ops.replicate(tensor, count=shard_count)
expected_result = ReplicatedTensor(ts=tensor, shard_count=shard_count)
expected_result = ReplicatedTensor(ts=[tensor] * shard_count)
assert expected_result.is_deep_equal(actual_result)

# Test that is a copy.
4 changes: 2 additions & 2 deletions sharktank/tests/types/tensors_test.py
@@ -106,7 +106,7 @@ def testUnreducedTensorSaveLoad(self):

def testReplicatedTensorExtractSlice(self):
tensor = torch.rand([2, 3, 4], dtype=torch.float32)
replicated_tensor = ReplicatedTensor(ts=tensor, shard_count=3)
replicated_tensor = ReplicatedTensor(ts=[tensor] * 3)
s = [slice(1, 2), slice(0, 3, 2), None]
expected_result = tensor[s]
replicated_sliced_tensor = replicated_tensor[s]
@@ -116,7 +116,7 @@ def testReplicatedTensorExtractElement(self):

def testReplicatedTensorExtractElement(self):
tensor = torch.rand([2, 3, 4], dtype=torch.float32)
replicated_tensor = ReplicatedTensor(ts=tensor, shard_count=3)
replicated_tensor = ReplicatedTensor(ts=[tensor] * 3)
idx = (
1,
2,
