basic sharding support for quant tensors
IanNod committed Oct 28, 2024
1 parent 6f3f8c7 commit f7e9a77
Showing 7 changed files with 233 additions and 46 deletions.
4 changes: 2 additions & 2 deletions sharktank/sharktank/ops/default_impls.py
@@ -438,8 +438,8 @@ def to_default(tensor: Tensor, *args, **kwargs):
return unbox_tensor(tensor).to(*args, **kwargs)


@transfer_to_logical_device.override(Tensor)
def transfer_to_logical_device_default(tensor: Tensor, ordinal: int):
@transfer_to_logical_device.override(AllOfType(AnyTensor, QuantizedTensor))
def transfer_to_logical_device_default(tensor, ordinal: int):
return iree.turbine.ops.iree.transfer_to_logical_device(
f"{ordinal}", unbox_tensor(tensor)
)
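The practical effect of the decorator change above is that the default transfer implementation can now be selected for quantized tensors as well as plain ones. Below is a self-contained toy (not sharktank code) illustrating the matching rule, under the assumption that AllOfType(...) means "every dispatch argument is an instance of one of the listed types":

class AllOfTypeToy:
    """Toy stand-in for the AllOfType dispatch key (semantics assumed)."""

    def __init__(self, *types):
        self.types = types

    def matches(self, *dispatch_args) -> bool:
        # Eligible only when every dispatch argument is one of the listed types.
        return all(isinstance(a, self.types) for a in dispatch_args)


key = AllOfTypeToy(int, float)
print(key.matches(3, 2.5))  # True  -> this override is a candidate
print(key.matches(3, "x"))  # False -> the dispatcher keeps looking
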
83 changes: 83 additions & 0 deletions sharktank/sharktank/ops/sharded_impls.py
@@ -15,6 +15,8 @@
AnyTensor,
DefaultPrimitiveTensor,
InferenceTensor,
QuantizedTensor,
PlanarQuantizedTensor,
PrimitiveTensor,
ReplicatedTensor,
ShardedTensor,
@@ -28,6 +30,8 @@
from .signatures import *
from .shape import broadcast_dims, broadcast_dim, unbroadcast_dim
from ..utils import longest_equal_range
from ..utils.math import ceildiv
from sharktank.types.tensors import REGISTERED_LAYOUT_CLASSES


@all_gather.override(SplitPrimitiveTensor)
@@ -1264,3 +1268,82 @@ def view_split(tensor: SplitPrimitiveTensor, shape: List[int]) -> SplitPrimitive
res = SplitPrimitiveTensor(shard_dim=shard_dim, ts=shards)
assert math.prod(res.shape) == math.prod(tensor.shape)
return res


@split.override(QuantizedTensor)
def split_QuantizedTensor(tensor: QuantizedTensor, split_size_or_sections, dim):
tensors = []
unpacked = tensor.unpack()
num_outputs = ceildiv(unpacked._qs.shape[dim], split_size_or_sections)
new_shape = unpacked._shape
new_shape[dim] = split_size_or_sections
new_qs = torch.split(unpacked._qs, split_size_or_sections, dim)
if unpacked._d.ndim > 0:
new_d = torch.split(unpacked._d, split_size_or_sections, dim)
if unpacked.serialized_name() == "SuperBlockOffsetScaled_4_6_Layout":
new_dmin = torch.split(unpacked._dmin, split_size_or_sections, dim)
new_sb_scales_high = torch.split(
unpacked._sb_scales_high, split_size_or_sections, dim
)
new_sb_scales_low = torch.split(
unpacked._sb_scales_low, split_size_or_sections, dim
)
new_sb_mins_high = torch.split(
unpacked._sb_mins_high, split_size_or_sections, dim
)
new_sb_mins_low = torch.split(
unpacked._sb_mins_low, split_size_or_sections, dim
)
for i in range(num_outputs):
layout_clazz = REGISTERED_LAYOUT_CLASSES[unpacked.serialized_name()]
new_layout = layout_clazz(
shape=new_shape,
d=new_d[i],
dmin=new_dmin[i],
sb_scales_high=new_sb_scales_high[i],
sb_scales_low=new_sb_scales_low[i],
sb_mins_high=new_sb_mins_high[i],
sb_mins_low=new_sb_mins_low[i],
qs=new_qs[i],
)
new_tensor = tensor.__class__
new_tensor_layout = new_layout.create(
new_layout.shape, new_layout.metadata, new_layout.planes
)
new_tensor = tensor.__class__(
shape=new_shape, layout=new_tensor_layout, name=tensor._name + str(i)
)
tensors.append(new_tensor)
else:
if split_size_or_sections > unpacked._qs.shape[dim]:
raise ValueError("split size greater than tensor dim")

if unpacked._m is not None:
if unpacked._m.ndim > 0:
new_m = torch.split(unpacked._m, split_size_or_sections, dim)
for i in range(num_outputs):
layout_clazz = REGISTERED_LAYOUT_CLASSES[unpacked.serialized_name()]
if unpacked._m is not None:
if unpacked._d.ndim > 0:
new_layout = layout_clazz(
shape=new_shape, d=new_d[i], qs=new_qs[i], m=new_m[i]
)
else:
new_layout = layout_clazz(
shape=new_shape, d=unpacked._d, qs=new_qs[i], m=unpacked._m
)
else:
if unpacked._d.ndim > 0:
new_layout = layout_clazz(shape=new_shape, d=new_d[i], qs=new_qs[i])
else:
new_layout = layout_clazz(
shape=new_shape, d=unpacked._d, qs=new_qs[i]
)
new_tensor_layout = new_layout.create(
new_layout.shape, new_layout.metadata, new_layout.planes
)
new_tensor = tensor.__class__(
shape=new_shape, layout=new_tensor_layout, name=tensor._name + str(i)
)
tensors.append(new_tensor)
return tensors
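
A minimal usage sketch for the new override follows (not part of the diff). It mirrors the BlockScaledLayout pattern used by the sharded matmul tests added later in this commit, with shapes shrunk for readability, and assumes the layout classes are importable from sharktank.types as they are in those tests:

import torch

from sharktank import ops
from sharktank.types import BlockScaledLayout, PlanarQuantizedTensor

# 64 rows, 4 blocks of 32 quantized samples per row (same pattern as the tests).
d = torch.rand([64, 4, 1], dtype=torch.float32)
qs = (torch.rand([64, 4, 32], dtype=torch.float32) * 32.0).to(torch.int8)
pqt = PlanarQuantizedTensor(
    shape=[64, 128], layout=BlockScaledLayout([64, 128], d, qs)
)

# Split into two quantized shards of 32 rows each along dim 0; each result is a
# new PlanarQuantizedTensor whose d/qs planes are the corresponding slices.
shards = ops.split(pqt, 32, dim=0)
assert len(shards) == 2
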
43 changes: 40 additions & 3 deletions sharktank/sharktank/ops/signatures.py
@@ -11,7 +11,15 @@
import torch
import numbers
from torch import Tensor, dtype
from ..types import AnyTensor, ShardedTensor, Theta, sharding, InferenceTensor
from ..types import (
AnyTensor,
ShardedTensor,
Theta,
sharding,
InferenceTensor,
QuantizedTensor,
PlanarQuantizedTensor,
)
from numbers import Number

from ._registry import *
@@ -59,6 +67,7 @@
"unshard",
"unsqueeze",
"view",
"split",
]

IntOrSequenceInt = Union[int, Sequence[int]]
@@ -976,14 +985,18 @@ def _to_trampoline(d: SignatureDispatcher, tensor: AnyTensor, *args, **kwargs):


@overridable
def transfer_to_logical_device(tensor: AnyTensor, ordinal: int) -> AnyTensor:
def transfer_to_logical_device(
tensor: Union[AnyTensor, QuantizedTensor, PlanarQuantizedTensor], ordinal: int
) -> Union[AnyTensor, QuantizedTensor, PlanarQuantizedTensor]:
"""Transfer the tensor to a device with ordinal `ordinal`."""
...


@transfer_to_logical_device.trampoline
def _transfer_to_logical_device_trampoline(
d: SignatureDispatcher, tensor: AnyTensor, ordinal: int
d: SignatureDispatcher,
tensor: Union[AnyTensor, QuantizedTensor, PlanarQuantizedTensor],
ordinal: int,
):
tensors = (tensor,)
for override in d.find_overrides(tensors):
@@ -1085,3 +1098,27 @@ def _view_trampoline(
return override, result
else:
d.fail(tensors)


@overridable
def split(
tensor: QuantizedTensor, split_size_or_sections: List[int], dim: int
) -> [QuantizedTensor]:
"""See torch.Tensor.split"""
...


@split.trampoline
def _split_trampoline(
d: SignatureDispatcher,
tensor: QuantizedTensor,
split_size_or_sections: List[int],
dim: int,
) -> [QuantizedTensor]:
tensors = (tensor,)
for override in d.find_overrides(tensors):
result = override(tensor, split_size_or_sections, dim)
if result is not NotImplemented:
return override, result
else:
d.fail(tensors)
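
The split entries above follow the same shape as every other signature in this file: an @overridable stub declares the public API, and the trampoline walks the registered overrides and returns the first result that is not NotImplemented. A self-contained toy of that routing (names here are illustrative and not sharktank's real dispatcher, which also matches on tensor types):

from typing import Callable, List


class ToyDispatcher:
    """Tiny stand-in for SignatureDispatcher: tries overrides in order and
    returns the first result that is not NotImplemented."""

    def __init__(self, name: str):
        self.name = name
        self._overrides: List[Callable] = []

    def override(self, fn: Callable) -> Callable:
        self._overrides.append(fn)
        return fn

    def __call__(self, *args, **kwargs):
        # Plays the role of the generated trampoline.
        for fn in self._overrides:
            result = fn(*args, **kwargs)
            if result is not NotImplemented:
                return result
        raise NotImplementedError(f"no override matched '{self.name}'")


split_toy = ToyDispatcher("split")


@split_toy.override
def split_list(tensor, split_size_or_sections, dim=0):
    # Toy override: only claims plain Python lists, defers everything else.
    if not isinstance(tensor, list):
        return NotImplemented
    n = split_size_or_sections
    return [tensor[i : i + n] for i in range(0, len(tensor), n)]


print(split_toy([1, 2, 3, 4, 5], 2))  # [[1, 2], [3, 4], [5]]
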
20 changes: 11 additions & 9 deletions sharktank/sharktank/types/sharding.py
@@ -60,15 +60,17 @@ def __init__(self, *args, **kwargs):
for k, v in d.items():
d[k] = tree.map_nodes(
tree=v,
f=lambda x: x
if isinstance(
x,
(
TensorSharding,
ThetaSharding,
),
)
else ThetaSharding(x),
f=lambda x: (
x
if isinstance(
x,
(
TensorSharding,
ThetaSharding,
),
)
else ThetaSharding(x)
),
)
super().__init__(d)

31 changes: 20 additions & 11 deletions sharktank/sharktank/types/tensors.py
@@ -581,6 +581,11 @@ def add_to_archive(self, builder: ShardedArchiveBuilder) -> InferenceTensorMetad
"""
return self.to_planar().add_to_archive(builder)

def split(self, split_size_or_sections: [int], dim: int) -> "[QuantizedTensor]":
from ..ops import split

return split(self, split_size_or_sections, dim)


@register_inference_tensor
class PlanarQuantizedTensor(QuantizedTensor):
@@ -764,12 +769,14 @@ def __init__(
assert shard_dim is None or (shard_dim >= 0 and len(ts[0].shape) > shard_dim)
super().__init__(name=name, shape=shape, shard_dim=shard_dim)
self._shards: tuple[DefaultPrimitiveTensor] = tuple(
DefaultPrimitiveTensor(
name=f"{name}.shard.{i}",
data=t,
(
DefaultPrimitiveTensor(
name=f"{name}.shard.{i}",
data=t,
)
if isinstance(t, torch.Tensor)
else t
)
if isinstance(t, torch.Tensor)
else t
for i, t in enumerate(ts)
)

@@ -941,7 +948,7 @@ def __init__(
will be split along dimension `shard_dim` into `shard_count`
number of pieces.
"""
if isinstance(ts, torch.Tensor):
if isinstance(ts, torch.Tensor) or isinstance(ts, InferenceTensor):
from ..ops import transfer_to_logical_device

assert shard_count is not None
@@ -1082,12 +1089,14 @@ def __init__(
assert shape == list(shard.shape)

self._shards: tuple[DefaultPrimitiveTensor] = tuple(
DefaultPrimitiveTensor(
name=f"{name}.shard.{i}",
data=t,
(
DefaultPrimitiveTensor(
name=f"{name}.shard.{i}",
data=t,
)
if isinstance(t, torch.Tensor)
else t
)
if isinstance(t, torch.Tensor)
else t
for i, t in enumerate(ts)
)

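Taken together with the ops changes above, the constructor change is what lets the new tests below hand an unsharded quantized tensor directly to SplitPrimitiveTensor. A minimal sketch of that flow (shapes shrunk from the tests; imports assumed to be re-exported from sharktank.types as in the test files):

import torch

from sharktank.types import (
    BlockScaledLayout,
    PlanarQuantizedTensor,
    SplitPrimitiveTensor,
)

d = torch.rand([64, 4, 1], dtype=torch.float32)
qs = (torch.rand([64, 4, 32], dtype=torch.float32) * 32.0).to(torch.int8)
rhs_pqt = PlanarQuantizedTensor(
    shape=[64, 128], layout=BlockScaledLayout([64, 128], d, qs)
)

# Previously `ts` had to be a torch.Tensor or an already-split list of shards;
# an InferenceTensor such as a PlanarQuantizedTensor is now accepted and split
# into `shard_count` quantized shards internally.
rhs_sharded = SplitPrimitiveTensor(shard_dim=0, ts=rhs_pqt, shard_count=2)
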
20 changes: 0 additions & 20 deletions sharktank/tests/ops/ops_test.py
@@ -194,26 +194,6 @@ def testTorchImplTransposedQuantizedRHS_BlockScaledLayout(self):
ops.custom_impls.matmul_generic_tensor_block_scaled,
)

def testTorchImplTransposedQuantizedRHS_BlockScaledOffsetI4(self):
ops._registry._test_enable_last_op_dispatch(True)
a_dtype = torch.float32
d_dtype = torch.float32
ref_dtype = torch.float32
a = torch.rand([4, 16, 3200], dtype=a_dtype) / 256.0
d = torch.rand([3200, 100, 1], dtype=d_dtype) / 256.0
qs = (torch.rand([3200, 100, 16], dtype=ref_dtype) * 255.0).to(torch.uint8)
m = torch.rand([3200, 100, 1], dtype=d_dtype) + 16.0
rhs_pqt = PlanarQuantizedTensor(
shape=[3200, 3200],
layout=BlockScaledI4Layout([3200, 3200], d, qs, m=m, signed=False),
)
result = ops.matmul(a, rhs_pqt, transpose_rhs=True)
# Just verifying dispatch. Numerics are tested at the kernel level.
self.assertIs(
ops._registry._test_get_last_op_dispatch(),
ops.custom_impls.matmul_generic_tensor_block_scaled_i4,
)

# TODO: mmt_super_block_scaled_offset_q4_unsigned


78 changes: 77 additions & 1 deletion sharktank/tests/ops/sharded_test.py
@@ -27,7 +27,6 @@ def testAllGather(self):

sharded = SplitPrimitiveTensor(shard_dim=shard_dim, ts=shards)
actual_result = ops.all_gather(sharded)

for shard in actual_result.shards:
torch.testing.assert_close(shard.as_torch(), expected_result)

@@ -770,6 +769,83 @@ def testSameSplitLhsAndRhsBatchDim(self):
actual_result = unbox_tensor(ops.unshard(sharded_result))
torch.testing.assert_close(actual_result, expected_result)

def testTranposedQuantizedRHSSharded_BlockScaledOffsetI4(self):
ops._registry._test_enable_last_op_dispatch(True)
a_dtype = torch.float32
d_dtype = torch.float32
ref_dtype = torch.float32
a = torch.rand([4, 16, 3200], dtype=a_dtype) / 256.0
d = torch.rand([3200, 100, 1], dtype=d_dtype) / 256.0
qs = (torch.rand([3200, 100, 16], dtype=ref_dtype) * 255.0).to(torch.uint8)
m = torch.rand([3200, 100, 1], dtype=d_dtype) + 16.0
rhs_pqt = PlanarQuantizedTensor(
shape=[3200, 3200],
layout=BlockScaledI4Layout([3200, 3200], d, qs, m=m, signed=False),
)
expected_result = ops.matmul(a, rhs_pqt, transpose_rhs=True)

shard_count = 2
rhs_pqt_sharded = SplitPrimitiveTensor(
shard_dim=0, ts=rhs_pqt, shard_count=shard_count
)

sharded_result = ops.matmul(a, rhs_pqt_sharded, transpose_rhs=True)
actual_result = ops.sharded_cat(sharded_result)

torch.testing.assert_close(actual_result, expected_result)

def testTorchImplTransposedQuantizedRHSSharded_BlockScaledLayout(self):
ops._registry._test_enable_last_op_dispatch(True)
a_dtype = torch.float32
d_dtype = torch.float32
ref_dtype = torch.float32
a = torch.rand([4, 16, 3200], dtype=a_dtype) * 64
d = torch.rand([3200, 100, 1], dtype=d_dtype) * 64
qs = (torch.rand([3200, 100, 32], dtype=ref_dtype) * 32.0).to(torch.int8)
rhs_pqt = PlanarQuantizedTensor(
shape=[3200, 3200], layout=BlockScaledLayout([3200, 3200], d, qs)
)
expected_result = ops.matmul(a, rhs_pqt, transpose_rhs=True)

shard_count = 2
rhs_pqt_sharded = SplitPrimitiveTensor(
shard_dim=0, ts=rhs_pqt, shard_count=shard_count
)

sharded_result = ops.matmul(a, rhs_pqt_sharded, transpose_rhs=True)
actual_result = ops.sharded_cat(sharded_result)

torch.testing.assert_close(actual_result, expected_result)

def testTorchImplTransposedQuantizedRHSSharded_TensorScaledLayout(self):
ops._registry._test_enable_last_op_dispatch(True)
a_dtype = torch.float32
d_dtype = torch.float32
ref_dtype = torch.float32
a = torch.rand([4, 16, 3200], dtype=a_dtype) * 64
d = torch.tensor(5.1, dtype=d_dtype) # torch.rand([3200], dtype=d_dtype)
qs = (torch.rand([3200, 3200], dtype=ref_dtype) * 32.0).to(torch.int8)
m = torch.tensor(
16.0, dtype=d_dtype
) # torch.rand([3200], dtype=d_dtype) + 16.0
rhs_pqt = PlanarQuantizedTensor(
shape=[3200, 3200],
layout=TensorScaledLayout(shape=[3200, 3200], d=d, qs=qs, m=m),
)
print("a shape:, ", a.shape)
print("rhs_pqt.shape: ", rhs_pqt.shape)
expected_result = ops.matmul(a, rhs_pqt, transpose_rhs=True)

shard_count = 2
rhs_pqt_sharded = SplitPrimitiveTensor(
shard_dim=0, ts=rhs_pqt, shard_count=shard_count
)

sharded_result = ops.matmul(a, rhs_pqt_sharded, transpose_rhs=True)
actual_result = ops.sharded_cat(sharded_result)

torch.testing.assert_close(actual_result, expected_result)


class ReplicateTest(unittest.TestCase):
def testReplicateReplicated(self):
