Skip to content

Commit fd598b0

Browse files
authored
Enable Float4_e2m1fn_x2 to Float4_e2m1fn_x2 copy (#2310)
To solve #2305. This PR adds support for copying tensors with the `Float4_e2m1fn_x2` data type on XPU devices.
1 parent 1e69f40 commit fd598b0

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

src/ATen/native/xpu/sycl/CopyKernel.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,16 @@ void float8_copy_kernel_xpu(TensorIteratorBase& iter) {
111111
}
112112
}
113113

114+
void float4_copy_kernel_xpu(TensorIteratorBase& iter) {
115+
ScalarType src_dtype = iter.dtype(1);
116+
117+
if (src_dtype == kFloat4_e2m1fn_x2) {
118+
gpu_kernel_nocast(iter, CopyScalarFunc<Float4_e2m1fn_x2>());
119+
} else {
120+
TORCH_CHECK(false, "Copy from ", src_dtype, " to Float4_e2m1fn_x2 has not been supported.");
121+
}
122+
}
123+
114124
void copy_kernel(TensorIteratorBase& iter) {
115125
ScalarType dtype = iter.common_dtype();
116126
if (isQIntType(dtype)) {
@@ -119,6 +129,8 @@ void copy_kernel(TensorIteratorBase& iter) {
119129
});
120130
} else if (isFloat8Type(iter.dtype(0))) {
121131
float8_copy_kernel_xpu(iter);
132+
} else if (iter.dtype(0) == kFloat4_e2m1fn_x2) {
133+
float4_copy_kernel_xpu(iter);
122134
} else {
123135
AT_DISPATCH_V2(
124136
dtype,

test/regressions/test_copy.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
class TestSimpleCopy(TestCase):
1515
@dtypes(*float8_types_and(torch.float8_e8m0fnu, torch.float32))
1616
def test_copy_and_clone(self, dtype):
17-
a_cpu = torch.randn(16, 64, 28, 28)
18-
b_cpu = torch.randn(16, 64, 28, 28)
17+
a_cpu = torch.randn(16, 64, 28, 28).to(dtype)
18+
b_cpu = torch.randn(16, 64, 28, 28).to(dtype)
1919
a_xpu = a_cpu.to(xpu_device)
2020
b_xpu = b_cpu.to(xpu_device)
2121
# naive
@@ -27,6 +27,22 @@ def test_copy_and_clone(self, dtype):
2727
b_xpu = a_xpu.clone(memory_format=torch.channels_last)
2828
self.assertEqual(b_cpu, b_xpu.to(cpu_device))
2929

30+
def test_copy_and_clone_float4(self):
31+
# Float4_e2m1fn_x2 copy is not implemented on CPU
32+
a_cpu = torch.randn(16, 64, 28, 28).to(torch.uint8)
33+
b_cpu = torch.randn(16, 64, 28, 28).to(torch.uint8)
34+
a_xpu = a_cpu.to(xpu_device).view(torch.float4_e2m1fn_x2)
35+
b_xpu = b_cpu.to(xpu_device).view(torch.float4_e2m1fn_x2)
36+
37+
b_cpu.copy_(a_cpu)
38+
b_xpu.copy_(a_xpu)
39+
# Float4_e2m1fn_x2 compare is not implemented on CPU
40+
self.assertEqual(b_cpu, b_xpu.view(torch.uint8).to(cpu_device))
41+
42+
b_cpu = a_cpu.clone(memory_format=torch.channels_last)
43+
b_xpu = a_xpu.clone(memory_format=torch.channels_last)
44+
self.assertEqual(b_cpu, b_xpu.view(torch.uint8).to(cpu_device))
45+
3046

3147
instantiate_device_type_tests(TestSimpleCopy, globals(), only_for="xpu", allow_xpu=True)
3248

0 commit comments

Comments
 (0)