Skip to content

Commit 48cd4bb

Browse files
committed
Add UT for FP4 Copy
1 parent 44b32df commit 48cd4bb

File tree

1 file changed

+18
-2
lines changed

1 file changed

+18
-2
lines changed

test/regressions/test_copy.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
class TestSimpleCopy(TestCase):
1515
@dtypes(*float8_types_and(torch.float8_e8m0fnu, torch.float32))
1616
def test_copy_and_clone(self, dtype):
17-
a_cpu = torch.randn(16, 64, 28, 28)
18-
b_cpu = torch.randn(16, 64, 28, 28)
17+
a_cpu = torch.randn(16, 64, 28, 28).to(dtype)
18+
b_cpu = torch.randn(16, 64, 28, 28).to(dtype)
1919
a_xpu = a_cpu.to(xpu_device)
2020
b_xpu = b_cpu.to(xpu_device)
2121
# naive
@@ -27,6 +27,22 @@ def test_copy_and_clone(self, dtype):
2727
b_xpu = a_xpu.clone(memory_format=torch.channels_last)
2828
self.assertEqual(b_cpu, b_xpu.to(cpu_device))
2929

30+
def test_copy_and_clone_float4(self):
    """Check copy_ and channels_last clone for float4_e2m1fn_x2 on XPU.

    The reference data is kept as uint8 on the CPU and the XPU tensors are
    reinterpreted as ``torch.float4_e2m1fn_x2`` via ``view``. Results are
    viewed back to uint8 before comparison, so the test checks that the raw
    bytes survive the copy/clone unchanged.
    """
    # Float4_e2m1fn_x2 copy is not implemented on CPU, so the CPU side
    # works on the raw uint8 bytes instead.
    # Sample the bytes directly: casting randn() output to uint8 yields
    # almost only 0/1/2 and relies on an ill-defined negative-float cast,
    # whereas randint covers every packed byte pattern.
    shape = (16, 64, 28, 28)
    a_cpu = torch.randint(0, 256, shape, dtype=torch.uint8)
    b_cpu = torch.randint(0, 256, shape, dtype=torch.uint8)
    a_xpu = a_cpu.to(xpu_device).view(torch.float4_e2m1fn_x2)
    b_xpu = b_cpu.to(xpu_device).view(torch.float4_e2m1fn_x2)

    # In-place copy: both destinations must end up holding a's bytes.
    b_cpu.copy_(a_cpu)
    b_xpu.copy_(a_xpu)
    # Float4_e2m1fn_x2 comparison is not implemented on CPU, so compare
    # the uint8 view of the XPU result against the uint8 reference.
    self.assertEqual(b_cpu, b_xpu.view(torch.uint8).to(cpu_device))

    # Clone into channels_last memory format and compare the same way.
    b_cpu = a_cpu.clone(memory_format=torch.channels_last)
    b_xpu = a_xpu.clone(memory_format=torch.channels_last)
    self.assertEqual(b_cpu, b_xpu.view(torch.uint8).to(cpu_device))
45+
3046

3147
# Instantiate TestSimpleCopy into this module's globals, restricted to the
# "xpu" device type (presumably generating the per-device/per-dtype test
# variants — confirm against the helper's definition).
instantiate_device_type_tests(TestSimpleCopy, globals(), only_for="xpu", allow_xpu=True)
3248

0 commit comments

Comments
 (0)