[DRAFT] Enable CPU data layout convert to XPU #2441

Draft · wants to merge 2 commits into base: main
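
For context, the intended end-to-end flow is sketched below. This is a minimal usage sketch, not part of the diff: it assumes an XPU-enabled PyTorch build and torchao's `quantize_` / `Int4WeightOnlyConfig` API with `layout=Int4CPULayout()`; only the final `.to("xpu")` call exercises the code added by this PR.

    import torch
    from torchao.dtypes import Int4CPULayout
    from torchao.quantization import Int4WeightOnlyConfig, quantize_

    # Quantize a module on CPU using the tinygemm int4 CPU layout.
    model = torch.nn.Linear(256, 256, dtype=torch.bfloat16)
    quantize_(model, Int4WeightOnlyConfig(group_size=128, layout=Int4CPULayout()))

    # With this PR, moving the module to XPU repacks each weight into the
    # XPU int4 layout instead of raising a ValueError.
    model = model.to("xpu")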
torchao/dtypes/uintx/int4_cpu_layout.py (16 changes: 14 additions and 2 deletions)
@@ -39,6 +39,8 @@ class Int4CPULayout(Layout):
     pass
 
 
+from torchao.dtypes.uintx.int4_xpu_layout import Int4XPUAQTTensorImpl
+
 @register_layout(Int4CPULayout)
 class Int4CPUAQTTensorImpl(AQTTensorImpl):
     """TensorImpl for int4 CPU layout for affine quantized tensor, this is for int4 only,
@@ -148,10 +150,16 @@ def from_plain(
     def to(self, *args, **kwargs):
         kwargs = self._get_to_kwargs(*args, **kwargs)
         device = kwargs["device"]
-        if not is_device(torch.device(self.device).type, device):
+        # Moving to XPU: unpack to the plain representation and repack it
+        # with the XPU layout on the target device.
+        if torch.device(device).type == "xpu":
+            from torchao.dtypes import Int4XPULayout
+            int_data, scale, zero_point = self.get_plain()
+            int_data, scale, zero_point = int_data.to(device), scale.to(device), zero_point.to(device)
+            return Int4XPUAQTTensorImpl.from_plain(int_data, scale, zero_point, _layout=Int4XPULayout())
+        elif not is_device(torch.device(self.device).type, device):
             raise ValueError(
                 f"Int4CPUAQTTensorImpl does not support conversion from {self.device} to {device}"
             )
 
         return self.__class__(
             self.packed_weight.to(device),
             self.scale_and_zero.to(device),
@@ -241,6 +249,10 @@ def block_size(self):
     def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros
 
+        # The tinygemm unpacking below is CPU-only, so move the packed
+        # tensors back to CPU first if they live on another device.
+        if self.device.type != "cpu":
+            self.scale_and_zero = self.scale_and_zero.to("cpu")
+            self.packed_weight = self.packed_weight.to("cpu")
+
         scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero)
 
         cur_shape = self.shape
@@ -249,7 +261,7 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         eye_shape = original_shape[1]
         groupsize = int(original_shape[1] / scale.shape[-2])
         block_size = (1, groupsize)
-        device = self.device
+        # Dequantization now always runs on CPU (see the device move above).
+        device = torch.device("cpu")
         original_dtype = self.scale_and_zero.dtype
         target_dtype = torch.int32
         quant_min = 0
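
A hypothetical smoke test for the new path is a plain round trip like the following. Assumptions beyond this diff: an XPU device is available, the quantized weight exposes its impl as `tensor_impl`, and `Int4XPUAQTTensorImpl.get_plain()` returns the same plain representation as the CPU impl.

    import torch

    # w: an AffineQuantizedTensor quantized with Int4CPULayout, e.g.
    # model.weight from the sketch above, before the .to("xpu") call.
    int_data, scale, zero_point = w.tensor_impl.get_plain()

    w_xpu = w.to("xpu")  # dispatches to the Int4CPUAQTTensorImpl.to() above
    int_data_x, scale_x, zero_x = w_xpu.tensor_impl.get_plain()

    # The plain quantized values should survive the repack unchanged.
    torch.testing.assert_close(int_data, int_data_x.cpu())
    torch.testing.assert_close(scale, scale_x.cpu())
    torch.testing.assert_close(zero_point, zero_x.cpu())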