diff --git a/sharktank/sharktank/ops/default_impls.py b/sharktank/sharktank/ops/default_impls.py index d7ed7b8e8..b04bf64b6 100644 --- a/sharktank/sharktank/ops/default_impls.py +++ b/sharktank/sharktank/ops/default_impls.py @@ -438,8 +438,8 @@ def to_default(tensor: Tensor, *args, **kwargs): return unbox_tensor(tensor).to(*args, **kwargs) -@transfer_to_logical_device.override(Tensor) -def transfer_to_logical_device_default(tensor: Tensor, ordinal: int): +@transfer_to_logical_device.override(AllOfType(AnyTensor, QuantizedTensor)) +def transfer_to_logical_device_default(tensor, ordinal: int): return iree.turbine.ops.iree.transfer_to_logical_device( f"{ordinal}", unbox_tensor(tensor) ) diff --git a/sharktank/sharktank/ops/sharded_impls.py b/sharktank/sharktank/ops/sharded_impls.py index 0dd0d2ae7..298349713 100644 --- a/sharktank/sharktank/ops/sharded_impls.py +++ b/sharktank/sharktank/ops/sharded_impls.py @@ -15,6 +15,8 @@ AnyTensor, DefaultPrimitiveTensor, InferenceTensor, + QuantizedTensor, + PlanarQuantizedTensor, PrimitiveTensor, ReplicatedTensor, ShardedTensor, @@ -28,6 +30,8 @@ from .signatures import * from .shape import broadcast_dims, broadcast_dim, unbroadcast_dim from ..utils import longest_equal_range +from ..utils.math import ceildiv +from sharktank.types.tensors import REGISTERED_LAYOUT_CLASSES @all_gather.override(SplitPrimitiveTensor) @@ -1264,3 +1268,82 @@ def view_split(tensor: SplitPrimitiveTensor, shape: List[int]) -> SplitPrimitive res = SplitPrimitiveTensor(shard_dim=shard_dim, ts=shards) assert math.prod(res.shape) == math.prod(tensor.shape) return res + + +@split.override(QuantizedTensor) +def split_QuantizedTensor(tensor: QuantizedTensor, split_size_or_sections, dim): + tensors = [] + unpacked = tensor.unpack() + num_outputs = ceildiv(unpacked._qs.shape[dim], split_size_or_sections) + new_shape = unpacked._shape + new_shape[dim] = split_size_or_sections + new_qs = torch.split(unpacked._qs, split_size_or_sections, dim) + if unpacked._d.ndim > 0: + new_d = torch.split(unpacked._d, split_size_or_sections, dim) + if unpacked.serialized_name() == "SuperBlockOffsetScaled_4_6_Layout": + new_dmin = torch.split(unpacked._dmin, split_size_or_sections, dim) + new_sb_scales_high = torch.split( + unpacked._sb_scales_high, split_size_or_sections, dim + ) + new_sb_scales_low = torch.split( + unpacked._sb_scales_low, split_size_or_sections, dim + ) + new_sb_mins_high = torch.split( + unpacked._sb_mins_high, split_size_or_sections, dim + ) + new_sb_mins_low = torch.split( + unpacked._sb_mins_low, split_size_or_sections, dim + ) + for i in range(num_outputs): + layout_clazz = REGISTERED_LAYOUT_CLASSES[unpacked.serialized_name()] + new_layout = layout_clazz( + shape=new_shape, + d=new_d[i], + dmin=new_dmin[i], + sb_scales_high=new_sb_scales_high[i], + sb_scales_low=new_sb_scales_low[i], + sb_mins_high=new_sb_mins_high[i], + sb_mins_low=new_sb_mins_low[i], + qs=new_qs[i], + ) + new_tensor = tensor.__class__ + new_tensor_layout = new_layout.create( + new_layout.shape, new_layout.metadata, new_layout.planes + ) + new_tensor = tensor.__class__( + shape=new_shape, layout=new_tensor_layout, name=tensor._name + str(i) + ) + tensors.append(new_tensor) + else: + if split_size_or_sections > unpacked._qs.shape[dim]: + raise ValueError("split size greater than tensor dim") + + if unpacked._m is not None: + if unpacked._m.ndim > 0: + new_m = torch.split(unpacked._m, split_size_or_sections, dim) + for i in range(num_outputs): + layout_clazz = REGISTERED_LAYOUT_CLASSES[unpacked.serialized_name()] + if unpacked._m is not None: + if unpacked._d.ndim > 0: + new_layout = layout_clazz( + shape=new_shape, d=new_d[i], qs=new_qs[i], m=new_m[i] + ) + else: + new_layout = layout_clazz( + shape=new_shape, d=unpacked._d, qs=new_qs[i], m=unpacked._m + ) + else: + if unpacked._d.ndim > 0: + new_layout = layout_clazz(shape=new_shape, d=new_d[i], qs=new_qs[i]) + else: + new_layout = layout_clazz( + shape=new_shape, d=unpacked._d, qs=new_qs[i] + ) + new_tensor_layout = new_layout.create( + new_layout.shape, new_layout.metadata, new_layout.planes + ) + new_tensor = tensor.__class__( + shape=new_shape, layout=new_tensor_layout, name=tensor._name + str(i) + ) + tensors.append(new_tensor) + return tensors diff --git a/sharktank/sharktank/ops/signatures.py b/sharktank/sharktank/ops/signatures.py index 89d4309ee..ed33740fe 100644 --- a/sharktank/sharktank/ops/signatures.py +++ b/sharktank/sharktank/ops/signatures.py @@ -11,7 +11,15 @@ import torch import numbers from torch import Tensor, dtype -from ..types import AnyTensor, ShardedTensor, Theta, sharding, InferenceTensor +from ..types import ( + AnyTensor, + ShardedTensor, + Theta, + sharding, + InferenceTensor, + QuantizedTensor, + PlanarQuantizedTensor, +) from numbers import Number from ._registry import * @@ -59,6 +67,7 @@ "unshard", "unsqueeze", "view", + "split", ] IntOrSequenceInt = Union[int, Sequence[int]] @@ -101,8 +110,9 @@ def _all_reduce_trampoline(d: SignatureDispatcher, tensor: AnyTensor): @overridable -def cat(tensors: Tuple[AnyTensor, ...] | List[AnyTensor], dim: int = 0) -> AnyTensor: - ... +def cat( + tensors: Tuple[AnyTensor, ...] | List[AnyTensor], dim: int = 0 +) -> AnyTensor: ... @cat.trampoline @@ -919,8 +929,7 @@ def _sharded_cat_trampoline(d: SignatureDispatcher, maybe_sharded: AnyTensor): @overridable -def sharded_sum(maybe_sharded: AnyTensor): - ... +def sharded_sum(maybe_sharded: AnyTensor): ... @sharded_sum.trampoline @@ -976,14 +985,18 @@ def _to_trampoline(d: SignatureDispatcher, tensor: AnyTensor, *args, **kwargs): @overridable -def transfer_to_logical_device(tensor: AnyTensor, ordinal: int) -> AnyTensor: +def transfer_to_logical_device( + tensor: Union[AnyTensor, QuantizedTensor, PlanarQuantizedTensor], ordinal: int +) -> Union[AnyTensor, QuantizedTensor, PlanarQuantizedTensor]: """Transfer the tensor to a device with ordinal `ordinal`.""" ... @transfer_to_logical_device.trampoline def _transfer_to_logical_device_trampoline( - d: SignatureDispatcher, tensor: AnyTensor, ordinal: int + d: SignatureDispatcher, + tensor: Union[AnyTensor, QuantizedTensor, PlanarQuantizedTensor], + ordinal: int, ): tensors = (tensor,) for override in d.find_overrides(tensors): @@ -1085,3 +1098,27 @@ def _view_trampoline( return override, result else: d.fail(tensors) + + +@overridable +def split( + tensor: QuantizedTensor, split_size_or_sections: List[int], dim: int +) -> [QuantizedTensor]: + """See torch.Tensor.split""" + ... + + +@split.trampoline +def _split_trampoline( + d: SignatureDispatcher, + tensor: QuantizedTensor, + split_size_or_sections: List[int], + dim: int, +) -> [QuantizedTensor]: + tensors = (tensor,) + for override in d.find_overrides(tensors): + result = override(tensor, split_size_or_sections, dim) + if result is not NotImplemented: + return override, result + else: + d.fail(tensors) diff --git a/sharktank/sharktank/types/sharding.py b/sharktank/sharktank/types/sharding.py index 81d2f31a5..e6c852ef5 100644 --- a/sharktank/sharktank/types/sharding.py +++ b/sharktank/sharktank/types/sharding.py @@ -60,15 +60,17 @@ def __init__(self, *args, **kwargs): for k, v in d.items(): d[k] = tree.map_nodes( tree=v, - f=lambda x: x - if isinstance( - x, - ( - TensorSharding, - ThetaSharding, - ), - ) - else ThetaSharding(x), + f=lambda x: ( + x + if isinstance( + x, + ( + TensorSharding, + ThetaSharding, + ), + ) + else ThetaSharding(x) + ), ) super().__init__(d) diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py index 226ffd777..810414a73 100644 --- a/sharktank/sharktank/types/tensors.py +++ b/sharktank/sharktank/types/tensors.py @@ -62,8 +62,7 @@ class QuantizedLayout(ABC): @abstractmethod - def dequant(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor: - ... + def dequant(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor: ... @classmethod @abstractmethod @@ -78,8 +77,7 @@ def create( shape: list[int], metadata: Optional[dict[str, MetaDataValueType]], planes: dict[str, torch.Tensor], - ) -> "QuantizedLayout": - ... + ) -> "QuantizedLayout": ... @property @abstractmethod @@ -559,8 +557,7 @@ def __init__( self.layout_type = layout_type @abstractmethod - def unpack(self) -> QuantizedLayoutT: - ... + def unpack(self) -> QuantizedLayoutT: ... def to_planar(self) -> "PlanarQuantizedTensor": """Converts this QuantizedTensor to a generic planar form. @@ -581,6 +578,11 @@ def add_to_archive(self, builder: ShardedArchiveBuilder) -> InferenceTensorMetad """ return self.to_planar().add_to_archive(builder) + def split(self, split_size_or_sections: [int], dim: int) -> "[QuantizedTensor]": + from ..ops import split + + return split(self, split_size_or_sections, dim) + @register_inference_tensor class PlanarQuantizedTensor(QuantizedTensor): @@ -717,8 +719,7 @@ def __init__( @property @abstractmethod - def shard_count(self) -> int: - ... + def shard_count(self) -> int: ... @property @abstractmethod @@ -764,12 +765,14 @@ def __init__( assert shard_dim is None or (shard_dim >= 0 and len(ts[0].shape) > shard_dim) super().__init__(name=name, shape=shape, shard_dim=shard_dim) self._shards: tuple[DefaultPrimitiveTensor] = tuple( - DefaultPrimitiveTensor( - name=f"{name}.shard.{i}", - data=t, + ( + DefaultPrimitiveTensor( + name=f"{name}.shard.{i}", + data=t, + ) + if isinstance(t, torch.Tensor) + else t ) - if isinstance(t, torch.Tensor) - else t for i, t in enumerate(ts) ) @@ -941,7 +944,7 @@ def __init__( will be split along dimension `shard_dim` into `shard_count` number of pieces. """ - if isinstance(ts, torch.Tensor): + if isinstance(ts, torch.Tensor) or isinstance(ts, InferenceTensor): from ..ops import transfer_to_logical_device assert shard_count is not None @@ -1082,12 +1085,14 @@ def __init__( assert shape == list(shard.shape) self._shards: tuple[DefaultPrimitiveTensor] = tuple( - DefaultPrimitiveTensor( - name=f"{name}.shard.{i}", - data=t, + ( + DefaultPrimitiveTensor( + name=f"{name}.shard.{i}", + data=t, + ) + if isinstance(t, torch.Tensor) + else t ) - if isinstance(t, torch.Tensor) - else t for i, t in enumerate(ts) ) diff --git a/sharktank/tests/ops/ops_test.py b/sharktank/tests/ops/ops_test.py index 1c1c06ed3..b8df61e40 100644 --- a/sharktank/tests/ops/ops_test.py +++ b/sharktank/tests/ops/ops_test.py @@ -194,26 +194,6 @@ def testTorchImplTransposedQuantizedRHS_BlockScaledLayout(self): ops.custom_impls.matmul_generic_tensor_block_scaled, ) - def testTorchImplTransposedQuantizedRHS_BlockScaledOffsetI4(self): - ops._registry._test_enable_last_op_dispatch(True) - a_dtype = torch.float32 - d_dtype = torch.float32 - ref_dtype = torch.float32 - a = torch.rand([4, 16, 3200], dtype=a_dtype) / 256.0 - d = torch.rand([3200, 100, 1], dtype=d_dtype) / 256.0 - qs = (torch.rand([3200, 100, 16], dtype=ref_dtype) * 255.0).to(torch.uint8) - m = torch.rand([3200, 100, 1], dtype=d_dtype) + 16.0 - rhs_pqt = PlanarQuantizedTensor( - shape=[3200, 3200], - layout=BlockScaledI4Layout([3200, 3200], d, qs, m=m, signed=False), - ) - result = ops.matmul(a, rhs_pqt, transpose_rhs=True) - # Just verifying dispatch. Numerics are tested at the kernel level. - self.assertIs( - ops._registry._test_get_last_op_dispatch(), - ops.custom_impls.matmul_generic_tensor_block_scaled_i4, - ) - # TODO: mmt_super_block_scaled_offset_q4_unsigned diff --git a/sharktank/tests/ops/sharded_test.py b/sharktank/tests/ops/sharded_test.py index c400bfa3c..77a8a55fa 100644 --- a/sharktank/tests/ops/sharded_test.py +++ b/sharktank/tests/ops/sharded_test.py @@ -27,7 +27,6 @@ def testAllGather(self): sharded = SplitPrimitiveTensor(shard_dim=shard_dim, ts=shards) actual_result = ops.all_gather(sharded) - for shard in actual_result.shards: torch.testing.assert_close(shard.as_torch(), expected_result) @@ -770,6 +769,83 @@ def testSameSplitLhsAndRhsBatchDim(self): actual_result = unbox_tensor(ops.unshard(sharded_result)) torch.testing.assert_close(actual_result, expected_result) + def testTranposedQuantizedRHSSharded_BlockScaledOffsetI4(self): + ops._registry._test_enable_last_op_dispatch(True) + a_dtype = torch.float32 + d_dtype = torch.float32 + ref_dtype = torch.float32 + a = torch.rand([4, 16, 3200], dtype=a_dtype) / 256.0 + d = torch.rand([3200, 100, 1], dtype=d_dtype) / 256.0 + qs = (torch.rand([3200, 100, 16], dtype=ref_dtype) * 255.0).to(torch.uint8) + m = torch.rand([3200, 100, 1], dtype=d_dtype) + 16.0 + rhs_pqt = PlanarQuantizedTensor( + shape=[3200, 3200], + layout=BlockScaledI4Layout([3200, 3200], d, qs, m=m, signed=False), + ) + expected_result = ops.matmul(a, rhs_pqt, transpose_rhs=True) + + shard_count = 2 + rhs_pqt_sharded = SplitPrimitiveTensor( + shard_dim=0, ts=rhs_pqt, shard_count=shard_count + ) + + sharded_result = ops.matmul(a, rhs_pqt_sharded, transpose_rhs=True) + actual_result = ops.sharded_cat(sharded_result) + + torch.testing.assert_close(actual_result, expected_result) + + def testTorchImplTransposedQuantizedRHSSharded_BlockScaledLayout(self): + ops._registry._test_enable_last_op_dispatch(True) + a_dtype = torch.float32 + d_dtype = torch.float32 + ref_dtype = torch.float32 + a = torch.rand([4, 16, 3200], dtype=a_dtype) * 64 + d = torch.rand([3200, 100, 1], dtype=d_dtype) * 64 + qs = (torch.rand([3200, 100, 32], dtype=ref_dtype) * 32.0).to(torch.int8) + rhs_pqt = PlanarQuantizedTensor( + shape=[3200, 3200], layout=BlockScaledLayout([3200, 3200], d, qs) + ) + expected_result = ops.matmul(a, rhs_pqt, transpose_rhs=True) + + shard_count = 2 + rhs_pqt_sharded = SplitPrimitiveTensor( + shard_dim=0, ts=rhs_pqt, shard_count=shard_count + ) + + sharded_result = ops.matmul(a, rhs_pqt_sharded, transpose_rhs=True) + actual_result = ops.sharded_cat(sharded_result) + + torch.testing.assert_close(actual_result, expected_result) + + def testTorchImplTransposedQuantizedRHSSharded_TensorScaledLayout(self): + ops._registry._test_enable_last_op_dispatch(True) + a_dtype = torch.float32 + d_dtype = torch.float32 + ref_dtype = torch.float32 + a = torch.rand([4, 16, 3200], dtype=a_dtype) * 64 + d = torch.tensor(5.1, dtype=d_dtype) # torch.rand([3200], dtype=d_dtype) + qs = (torch.rand([3200, 3200], dtype=ref_dtype) * 32.0).to(torch.int8) + m = torch.tensor( + 16.0, dtype=d_dtype + ) # torch.rand([3200], dtype=d_dtype) + 16.0 + rhs_pqt = PlanarQuantizedTensor( + shape=[3200, 3200], + layout=TensorScaledLayout(shape=[3200, 3200], d=d, qs=qs, m=m), + ) + print("a shape:, ", a.shape) + print("rhs_pqt.shape: ", rhs_pqt.shape) + expected_result = ops.matmul(a, rhs_pqt, transpose_rhs=True) + + shard_count = 2 + rhs_pqt_sharded = SplitPrimitiveTensor( + shard_dim=0, ts=rhs_pqt, shard_count=shard_count + ) + + sharded_result = ops.matmul(a, rhs_pqt_sharded, transpose_rhs=True) + actual_result = ops.sharded_cat(sharded_result) + + torch.testing.assert_close(actual_result, expected_result) + class ReplicateTest(unittest.TestCase): def testReplicateReplicated(self):