Merge branch 'perplexity-test' of https://github.com/nod-ai/SHARK-Platform into perplexity-test

archana-ramalingam committed Oct 19, 2024
2 parents 726ec7e + 7d440e3 commit 71a2f1e
Showing 9 changed files with 118 additions and 30 deletions.
3 changes: 2 additions & 1 deletion sharktank/sharktank/kernels/batch_matmul_transpose_b.py
@@ -80,7 +80,7 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
spec_sig = f"L{a_ident}_R{b_ident}"
template_file = "batch_matmul_transpose_b.mlir"
target_function_name = f"sharktank_batch_matmul_transpose_b_{spec_sig}"

cst_zero = "0." if "f" in str(accum_type) else "0"
# Template params.
c_asm_type = f"tensor<{'x'.join('?' if d is None else str(d) for d in result_desc.spec_dims)}x{accum_type}>"

@@ -93,5 +93,6 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
b_asm_type=b_asm_type,
c_asm_type=c_asm_type,
dtype=str(accum_type),
cst_zero=cst_zero,
)
kb.yield_results(*call_function(target_function, *kb.arg_bindings))
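The new cst_zero template parameter exists because the zero constant in the MLIR template must be spelled differently per accumulator type: a float literal ("0.") for floating-point accumulators and an integer literal ("0") for integer ones, presumably so the rendered arith.constant parses for both. A minimal standalone sketch of that selection, using a hypothetical helper name and MLIR type strings rather than the torch dtypes the generator actually passes:

# Hypothetical helper (not part of the kernel generator) showing the cst_zero choice.
def pick_zero_literal(accum_type: str) -> str:
    # Float accumulators ("f16", "f32") need "0."; integer ones ("i32") need "0".
    return "0." if "f" in accum_type else "0"

assert pick_zero_literal("f32") == "0."  # renders: %zero = arith.constant 0. : f32
assert pick_zero_literal("i32") == "0"   # renders: %zero = arith.constant 0 : i32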
3 changes: 3 additions & 0 deletions sharktank/sharktank/kernels/conv_2d_nchw_fchw.py
@@ -23,6 +23,8 @@
(torch.int16, torch.int16, "torch.int16"): torch.int16,
(torch.int16, torch.int16, "torch.int32"): torch.int32,
# Legal fp types.
(torch.float8_e4m3fnuz, torch.float8_e4m3fnuz, "torch.float16"): torch.float16,
(torch.float8_e4m3fnuz, torch.float8_e4m3fnuz, "torch.float32"): torch.float32,
(torch.float16, torch.float16, "torch.float16"): torch.float16,
(torch.float16, torch.float16, "torch.float32"): torch.float32,
(torch.float32, torch.float32, "torch.float32"): torch.float32,
@@ -33,6 +35,7 @@
torch.int8: "i8",
torch.int16: "i16",
torch.int32: "i32",
torch.float8_e4m3fnuz: "f8E4M3FNUZ",
torch.float16: "f16",
torch.float32: "f32",
}
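These two tables now admit float8_e4m3fnuz: fp8 inputs may accumulate into float16 or float32, and the dtype-to-MLIR-name map learns the f8E4M3FNUZ spelling. A small sketch of how such tables are typically consulted; the dictionary names below are placeholders, not the module's real identifiers:

import torch

# Placeholder names; the real module-level dicts are only partially shown above.
_LEGAL_ACCUM = {
    (torch.float8_e4m3fnuz, torch.float8_e4m3fnuz, "torch.float32"): torch.float32,
    (torch.float8_e4m3fnuz, torch.float8_e4m3fnuz, "torch.float16"): torch.float16,
}
_ASM_NAME = {torch.float8_e4m3fnuz: "f8E4M3FNUZ", torch.float16: "f16", torch.float32: "f32"}

accum = _LEGAL_ACCUM[(torch.float8_e4m3fnuz, torch.float8_e4m3fnuz, "torch.float32")]
print(f"fp8 inputs ({_ASM_NAME[torch.float8_e4m3fnuz]}) accumulate as {_ASM_NAME[accum]}")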
batch_matmul_transpose_b.mlir
@@ -4,6 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

// !cst_zero = {{cst_zero}}
!dtype = {{dtype}}
!a_tensor_type = {{a_asm_type}}
!b_tensor_type = {{b_asm_type}}
@@ -15,7 +16,8 @@ module {
util.func private @sharktank_batch_matmul_transpose_b_{{spec_sig}}(
%a: !a_tensor_type, %b: !b_tensor_type)
-> !c_tensor_type {
%zero = arith.constant 0: !dtype
// %zero = arith.constant !cst_zero: !dtype
%zero = arith.constant {{cst_zero}}: !dtype
%c0 = arith.constant 0: index
%c1 = arith.constant 1: index
%batch_dim = tensor.dim %a, %c0 : !a_tensor_type // b, m, k
@@ -208,7 +208,7 @@ def quantize_bias(
bias_scale = 1.0 / (input_scale * weight_scale)
bias_quantizer = StaticScaledQuantizer(
scale=bias_scale,
dtype=torch.int32,
dtype=torch.int32 if quantization_dtype == torch.int8 else torch.float16,
disable_saturate=True,
)
bias_quant = bias_quantizer.quantize(bias, name=bias_name)
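The bias dtype now follows the quantization dtype: int8-quantized models keep the int32 bias as before, while other quantizations (such as the fp8 path this commit enables) quantize the bias to float16. A sketch of the selection, assuming StaticScaledQuantizer is importable from sharktank.types and that bias_scale has been computed as above:

import torch
from sharktank.types import StaticScaledQuantizer  # assumed import path

def make_bias_quantizer(bias_scale, quantization_dtype):
    # int8 models accumulate an int32 bias; fp8 (and other float) models keep fp16.
    return StaticScaledQuantizer(
        scale=bias_scale,
        dtype=torch.int32 if quantization_dtype == torch.int8 else torch.float16,
        disable_saturate=True,
    )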
29 changes: 17 additions & 12 deletions sharktank/sharktank/ops/qconv_impls.py
@@ -31,7 +31,7 @@
)


def qconv2d_tensor_scaled_integer(
def qconv2d_tensor_scaled(
input: QuantizedTensor,
weight: QuantizedTensor,
bias: Optional[AnyTensor] = None,
@@ -59,12 +59,16 @@ def qconv2d_tensor_scaled_integer(
input_layout: TensorScaledLayout = input.unpack()
weight_layout: TensorScaledLayout = weight.unpack()

# Only handle integer quantizations.
# Handle integer and fp8 quantizations.
if (
input_layout.qs.dtype.is_floating_point
or weight_layout.qs.dtype.is_floating_point
):
return NotImplemented
if (
input_layout.qs.dtype != torch.float8_e4m3fnuz
or weight_layout.qs.dtype != torch.float8_e4m3fnuz
):
return NotImplemented

# Bias is both optional and may either be quantized or fp.
bias_qs = None
@@ -85,7 +89,10 @@

# Alias components (d=scale, qs=quantized samples, m=offset).
if accum_dtype is None:
accum_dtype = torch.int32
if weight_layout.qs.dtype.is_floating_point:
accum_dtype = torch.float32
else:
accum_dtype = torch.int32
input_d = input_layout.d
input_dtype = input_layout.dtype
input_qs = input_layout.qs
@@ -114,7 +121,7 @@
dilation = _expand_int_to_2_tuple(dilation)
extended_padding_list = [item for item in padding for _ in range(2)]
padded_input = _pad_last_2d(input_qs, extended_padding_list)
y_qs = _invoke_int32_conv2d(
y_qs = _invoke_conv2d_kernel(
padded_input,
weight_qs,
bias_qs.to(accum_dtype) if bias_qs is not None else None,
@@ -145,7 +152,7 @@ def qconv2d_tensor_scaled_integer(
weight_offset_fix = torch.sum(
padded_input, dim=1, keepdim=True, dtype=accum_dtype
)
weight_offset_fix = _invoke_int32_pooling_sum(
weight_offset_fix = _invoke_pooling_sum_kernel(
weight_offset_fix,
[weight_qs.shape[2], weight_qs.shape[3]],
stride,
@@ -188,13 +195,11 @@ def qconv2d_tensor_scaled_integer(
return y
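Taken together, the gate and accumulator changes above mean the op now dispatches for integer-quantized tensors as before and additionally for float8_e4m3fnuz, while any other floating-point quantization still returns NotImplemented; the default accumulator becomes float32 for float inputs and stays int32 for integer inputs. A condensed, standalone restatement against plain dtypes (the real code reads them off TensorScaledLayout objects):

import torch

def accepts(input_qs_dtype, weight_qs_dtype):
    if input_qs_dtype.is_floating_point or weight_qs_dtype.is_floating_point:
        # Among float formats, only fp8 e4m3fnuz quantization is handled here.
        return (
            input_qs_dtype == torch.float8_e4m3fnuz
            and weight_qs_dtype == torch.float8_e4m3fnuz
        )
    return True  # integer quantizations were already supported

def default_accum(weight_qs_dtype):
    return torch.float32 if weight_qs_dtype.is_floating_point else torch.int32

assert accepts(torch.int8, torch.int8)
assert accepts(torch.float8_e4m3fnuz, torch.float8_e4m3fnuz)
assert not accepts(torch.float16, torch.float16)
assert default_accum(torch.float8_e4m3fnuz) == torch.float32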


conv2d.override(QuantizedTensor, QuantizedTensor)(qconv2d_tensor_scaled_integer)
conv2d.override(QuantizedTensor, QuantizedTensor, AnyTensor)(
qconv2d_tensor_scaled_integer
)
conv2d.override(QuantizedTensor, QuantizedTensor)(qconv2d_tensor_scaled)
conv2d.override(QuantizedTensor, QuantizedTensor, AnyTensor)(qconv2d_tensor_scaled)


def _invoke_int32_conv2d(input, weight, bias, stride, dilation, *, accum_dtype):
def _invoke_conv2d_kernel(input, weight, bias, stride, dilation, *, accum_dtype):
"""Does a low level invocation of a conv2d integer kernel on an explicitly padded input.
This presumes that the stride/padding/dilation have already been normalized
@@ -233,7 +238,7 @@ def _invoke_int32_conv2d(input, weight, bias, stride, dilation, *, accum_dtype):
return y_qs


def _invoke_int32_pooling_sum(input, kernel_size, stride, dilation, *, accum_dtype):
def _invoke_pooling_sum_kernel(input, kernel_size, stride, dilation, *, accum_dtype):
"""Invokes either a custom integer pooling sum or the built-in fp avg_pool2d
kernel on an explicitly padded input.
"""
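On the fp fallback path, _invoke_pooling_sum_kernel relies on the built-in avg_pool2d; a windowed sum can be expressed through it by overriding the divisor, assuming that path needs no dilation (avg_pool2d takes none). A sketch of that equivalence:

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 8, 8)
kernel_size, stride = [3, 3], [1, 1]
# divisor_override=1 turns the average over each window into a plain sum.
pooled_sum = F.avg_pool2d(x, kernel_size, stride=stride, divisor_override=1)
reference = F.avg_pool2d(x, kernel_size, stride=stride) * 9  # 3x3 window area
torch.testing.assert_close(pooled_sum, reference)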
35 changes: 25 additions & 10 deletions sharktank/sharktank/ops/qlinear_impls.py
@@ -28,7 +28,7 @@
from sharktank import kernels


def qlinear_tensor_scaled_integer(
def qlinear_tensor_scaled(
x: QuantizedTensor,
weight: QuantizedTensor,
bias: Optional[AnyTensor],
@@ -48,8 +48,11 @@ def qlinear_tensor_scaled_integer(
x_layout: TensorScaledLayout = x.unpack()
weight_layout: TensorScaledLayout = weight.unpack()

# Only handle integer quantizations.
if x_layout.qs.dtype.is_floating_point or weight_layout.qs.dtype.is_floating_point:
# Handle integer and fp8 quantizations.
if (
x_layout.qs.dtype != torch.float8_e4m3fnuz
and weight_layout.qs.dtype != torch.float8_e4m3fnuz
):
return NotImplemented

# Bias.
@@ -64,7 +67,10 @@

# Alias components (d=scale, qs=quantized samples, m=offset)
if accum_dtype is None:
accum_dtype = torch.int32
if weight_layout.qs.dtype.is_floating_point:
accum_dtype = torch.float32
else:
accum_dtype = torch.int32
x_d = x_layout.d
x_dtype = x_layout.dtype
x_qs = x_layout.qs
@@ -86,7 +92,7 @@
# TODO: Handle permutation that we have a kernel for.

# Fall back to automatic fusion based on integer, high precision matmul.
y_qs = _invoke_int32_mmt(x_qs, weight_qs, accum_dtype=accum_dtype)
y_qs = _invoke_mmt_kernel(x_qs, weight_qs, accum_dtype=accum_dtype)

# Offset correction. By applying the offset correction in post, it is
# set up to fuse with its consumer, which is already doing additional
@@ -143,10 +149,8 @@ def qlinear_tensor_scaled_integer(


# Overload for both bias and no bias.
linear.override(QuantizedTensor, QuantizedTensor)(qlinear_tensor_scaled_integer)
linear.override(QuantizedTensor, QuantizedTensor, AnyTensor)(
qlinear_tensor_scaled_integer
)
linear.override(QuantizedTensor, QuantizedTensor)(qlinear_tensor_scaled)
linear.override(QuantizedTensor, QuantizedTensor, AnyTensor)(qlinear_tensor_scaled)


def linear_quantized_weight(
@@ -166,19 +170,30 @@ def linear_quantized_weight(
linear.override(Tensor, QuantizedTensor, AnyTensor)(linear_quantized_weight)


def _invoke_int32_mmt(lhs, rhs, *, accum_dtype):
def _invoke_mmt_kernel(lhs, rhs, *, accum_dtype):
if debugging.flags.use_custom_iree_kernels:
# The custom kernel requires that the lhs and rhs be the same
# rank. Broadcast the rhs to match.
lhs_rank = len(lhs.shape)
rhs_rank = len(rhs.shape)
# If input to the kernel is 2D, expand the tensor to add the batch
# dimension.
if lhs_rank == 2:
lhs_size = [1] + list(lhs.shape)
lhs = lhs.unsqueeze(0).expand(lhs_size)
lhs_rank = len(lhs.shape)
if rhs_rank < lhs_rank:
assert (rhs_rank + 1) == lhs_rank
rhs_size = [lhs.shape[0]] + list(rhs.shape)
rhs = rhs.unsqueeze(0).expand(rhs_size)
rhs_rank = len(rhs.shape)
y_qs = kernels.batch_matmul_transpose_b(
lhs.to(accum_dtype), rhs.to(accum_dtype)
)
# Squeeze the batch dimension to maintain shape parity with other
# layers.
if len(y_qs.shape) > 2:
y_qs = y_qs.squeeze(0)
else:
# FP emulation.
y_qs = torch.matmul(
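The custom batch_matmul_transpose_b kernel wants equal-rank, batched operands, so _invoke_mmt_kernel now lifts a 2D lhs into a batch of one, broadcasts the rhs across that batch, and squeezes the batch dimension back off the result. A standalone sketch of the same rank handling, with plain torch ops standing in for the kernel:

import torch

def mmt_with_rank_normalization(lhs, rhs):
    # Sketch only: torch.matmul on a transposed rhs stands in for the custom kernel.
    squeeze_back = False
    if lhs.dim() == 2:
        lhs = lhs.unsqueeze(0)  # add a batch-of-one dimension
        squeeze_back = True
    if rhs.dim() + 1 == lhs.dim():
        rhs = rhs.unsqueeze(0).expand([lhs.shape[0]] + list(rhs.shape))  # broadcast over batch
    y = torch.matmul(lhs, rhs.transpose(-1, -2))
    # Drop the synthetic batch dimension to keep shape parity with other layers.
    return y.squeeze(0) if squeeze_back else y

out = mmt_with_rank_normalization(torch.randn(4, 8), torch.randn(16, 8))
assert out.shape == (4, 16)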
10 changes: 5 additions & 5 deletions sharktank/tests/ops/qconv_test.py
@@ -71,7 +71,7 @@ def testInputSymPerTensor_WeightAsymPerChannel_NoBias(self):
)
self.assertIs(
ops._registry._test_get_last_op_dispatch(),
ops.qconv_impls.qconv2d_tensor_scaled_integer,
ops.qconv_impls.qconv2d_tensor_scaled,
)
y_ref = torch.nn.functional.conv2d(
input_q.unpack().dequant(),
@@ -105,7 +105,7 @@ def testInputSymPerTensor_WeightAsymPerChannel_FloatBias(self):
y_actual = ops.conv2d(input_q, weight_q, bias, stride=1, padding=(1, 1))
self.assertIs(
ops._registry._test_get_last_op_dispatch(),
ops.qconv_impls.qconv2d_tensor_scaled_integer,
ops.qconv_impls.qconv2d_tensor_scaled,
)
y_ref = torch.nn.functional.conv2d(
input_q.unpack().dequant(),
@@ -147,7 +147,7 @@ def testInputSymPerTensor_WeightAsymPerChannel_QuantizedBias(self):
)
self.assertIs(
ops._registry._test_get_last_op_dispatch(),
ops.qconv_impls.qconv2d_tensor_scaled_integer,
ops.qconv_impls.qconv2d_tensor_scaled,
)
y_ref = torch.nn.functional.conv2d(
input_q.unpack().dequant(),
@@ -184,7 +184,7 @@ def testInputSymPerTensor_WeightSymPerTensor_NoBias(self):
)
self.assertIs(
ops._registry._test_get_last_op_dispatch(),
ops.qconv_impls.qconv2d_tensor_scaled_integer,
ops.qconv_impls.qconv2d_tensor_scaled,
)
y_ref = torch.nn.functional.conv2d(
input_q.unpack().dequant(),
@@ -224,7 +224,7 @@ def testInputAsymPerChannel_WeightAsymPerChannel_NoBias(self):
)
self.assertIs(
ops._registry._test_get_last_op_dispatch(),
ops.qconv_impls.qconv2d_tensor_scaled_integer,
ops.qconv_impls.qconv2d_tensor_scaled,
)
y_ref = torch.nn.functional.conv2d(
input_q.unpack().dequant(),
54 changes: 54 additions & 0 deletions sharktank/tests/ops/sharded_test.py
@@ -810,6 +810,42 @@ def testSplitTensorSplitDimIsLeadingFlattenDim(self):
actual_result = ops.reshape(sharded_tensor, new_shape)
assert expected_result.is_deep_equal(actual_result)

def testSplitTensorInsertSize1DimBeforeSplitDim(self):
tensor = torch.rand(4, 5, 6, 7)
new_shape = [4, 1, 5, 6, 7]
unsharded_expected_result = torch.reshape(tensor, new_shape)
shard_dim = 2
expected_result = ops.reshard_split(
unsharded_expected_result, dim=shard_dim + 1, count=2
)
sharded_tensor = ops.reshard_split(tensor, dim=shard_dim, count=2)
actual_result = ops.reshape(sharded_tensor, new_shape)
assert expected_result.is_deep_equal(actual_result)

def testSplitTensorInsertMultipleSize1DimsBeforeSplitDim(self):
tensor = torch.rand(4, 5, 6, 7)
new_shape = [4, 1, 1, 5, 6, 7]
unsharded_expected_result = torch.reshape(tensor, new_shape)
shard_dim = 2
expected_result = ops.reshard_split(
unsharded_expected_result, dim=shard_dim + 2, count=2
)
sharded_tensor = ops.reshard_split(tensor, dim=shard_dim, count=2)
actual_result = ops.reshape(sharded_tensor, new_shape)
assert expected_result.is_deep_equal(actual_result)

def testSplitTensorInsertMultipleSize1TrailingDimsNotRightAfterSplitDim(self):
tensor = torch.rand(4, 5, 6, 7)
new_shape = [4, 5, 6, 7, 1, 1]
unsharded_expected_result = torch.reshape(tensor, new_shape)
shard_dim = 2
expected_result = ops.reshard_split(
unsharded_expected_result, dim=shard_dim, count=2
)
sharded_tensor = ops.reshard_split(tensor, dim=shard_dim, count=2)
actual_result = ops.reshape(sharded_tensor, new_shape)
assert expected_result.is_deep_equal(actual_result)

def testSplitTensorUnflattenNonSplitDim(self):
tensor = torch.rand(3, 20, 6)
new_shape = [3, 4, 5, 6]
@@ -819,6 +855,15 @@ def testSplitTensorUnflattenNonSplitDim(self):
actual_result = ops.reshape(sharded_tensor, new_shape)
assert expected_result.is_deep_equal(actual_result)

def testSplitTensorUnflattenTrailingNonSplitDim(self):
tensor = torch.rand(3, 4, 30)
new_shape = [3, 4, 5, 6]
unsharded_expected_result = torch.reshape(tensor, new_shape)
expected_result = ops.reshard_split(unsharded_expected_result, dim=1, count=2)
sharded_tensor = ops.reshard_split(tensor, dim=1, count=2)
actual_result = ops.reshape(sharded_tensor, new_shape)
assert expected_result.is_deep_equal(actual_result)

def testSplitTensorUnflattenSplitDim(self):
tensor = torch.rand(3, 20, 6)
new_shape = [3, 4, 5, 6]
@@ -828,6 +873,15 @@ def testSplitTensorUnflattenSplitDim(self):
actual_result = ops.reshape(sharded_tensor, new_shape)
assert expected_result.is_deep_equal(actual_result)

def testSplitTensorUnflattenTrailingSplitDim(self):
tensor = torch.rand(2, 3, 20)
new_shape = [2, 3, 4, 5]
unsharded_expected_result = torch.reshape(tensor, new_shape)
expected_result = ops.reshard_split(unsharded_expected_result, dim=2, count=2)
sharded_tensor = ops.reshard_split(tensor, dim=2, count=2)
actual_result = ops.reshape(sharded_tensor, new_shape)
assert expected_result.is_deep_equal(actual_result)


class ReshardSplitTest(unittest.TestCase):
def testReshardReplicated(self):
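The new reshape tests pin down how the split dimension of a sharded tensor moves under reshape: each size-1 dimension inserted before the split dim shifts it right by one, while trailing size-1 dims and unflattened non-split dims leave it in place. A hypothetical helper (not library code) capturing the insert-size-1 cases checked above:

def shifted_shard_dim(old_shape, new_shape, shard_dim):
    # Count size-1 dims inserted at or before the split dim; each shifts it right by one.
    inserted_before = 0
    i = 0
    for size in new_shape:
        if i < len(old_shape) and size == old_shape[i]:
            i += 1
        elif size == 1 and i <= shard_dim:
            inserted_before += 1
    return shard_dim + inserted_before

assert shifted_shard_dim((4, 5, 6, 7), (4, 1, 5, 6, 7), 2) == 3
assert shifted_shard_dim((4, 5, 6, 7), (4, 1, 1, 5, 6, 7), 2) == 4
assert shifted_shard_dim((4, 5, 6, 7), (4, 5, 6, 7, 1, 1), 2) == 2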
8 changes: 8 additions & 0 deletions shortfin/tests/host_cpu_system_test.py
@@ -7,6 +7,7 @@
import os
import pytest
import re
import sys

import shortfin as sf

@@ -19,6 +20,9 @@ def test_create_host_cpu_system_defaults():
assert len(ls.devices) > 0


@pytest.mark.skipif(
sys.platform == "win32", reason="Windows fatal exception: access violation"
)
def test_create_host_cpu_system_topology_nodes_all():
sc = sf.host.CPUSystemBuilder(
hostcpu_topology_nodes="all", hostcpu_topology_max_group_count=2
@@ -39,6 +43,10 @@ def test_create_host_cpu_system_topology_nodes_explicit():
assert len(ls.devices) == 2


@pytest.mark.skipif(
sys.platform == "win32",
reason="Only detecting 1 device, check config setup from env vars?",
)
def test_create_host_cpu_system_env_vars():
os.environ["SHORTFIN_HOSTCPU_TOPOLOGY_NODES"] = "0,0"
os.environ["SHORTFIN_HOSTCPU_TOPOLOGY_MAX_GROUP_COUNT"] = "2"
