1414class TestSimpleCopy (TestCase ):
1515 @dtypes (* float8_types_and (torch .float8_e8m0fnu , torch .float32 ))
1616 def test_copy_and_clone (self , dtype ):
17- a_cpu = torch .randn (16 , 64 , 28 , 28 )
18- b_cpu = torch .randn (16 , 64 , 28 , 28 )
17+ a_cpu = torch .randn (16 , 64 , 28 , 28 ). to ( dtype )
18+ b_cpu = torch .randn (16 , 64 , 28 , 28 ). to ( dtype )
1919 a_xpu = a_cpu .to (xpu_device )
2020 b_xpu = b_cpu .to (xpu_device )
2121 # naive
@@ -27,6 +27,22 @@ def test_copy_and_clone(self, dtype):
2727 b_xpu = a_xpu .clone (memory_format = torch .channels_last )
2828 self .assertEqual (b_cpu , b_xpu .to (cpu_device ))
2929
30+ def test_copy_and_clone_float4 (self ):
31+ # Float4_e2m1fn_x2 copy is not implemented by CPU
32+ a_cpu = torch .randn (16 , 64 , 28 , 28 ).to (torch .uint8 )
33+ b_cpu = torch .randn (16 , 64 , 28 , 28 ).to (torch .uint8 )
34+ a_xpu = a_cpu .to (xpu_device ).view (torch .float4_e2m1fn_x2 )
35+ b_xpu = b_cpu .to (xpu_device ).view (torch .float4_e2m1fn_x2 )
36+
37+ b_cpu .copy_ (a_cpu )
38+ b_xpu .copy_ (a_xpu )
39+ # Float4_e2m1fn_x2 compare is not implemented CPU
40+ self .assertEqual (b_cpu , b_xpu .view (torch .uint8 ).to (cpu_device ))
41+
42+ b_cpu = a_cpu .clone (memory_format = torch .channels_last )
43+ b_xpu = a_xpu .clone (memory_format = torch .channels_last )
44+ self .assertEqual (b_cpu , b_xpu .view (torch .uint8 ).to (cpu_device ))
45+
3046
3147instantiate_device_type_tests (TestSimpleCopy , globals (), only_for = "xpu" , allow_xpu = True )
3248
0 commit comments