diff --git a/src/ATen/native/xpu/sycl/CopyKernel.cpp b/src/ATen/native/xpu/sycl/CopyKernel.cpp index 9a8a92ad5..e4fe93a57 100644 --- a/src/ATen/native/xpu/sycl/CopyKernel.cpp +++ b/src/ATen/native/xpu/sycl/CopyKernel.cpp @@ -88,6 +88,21 @@ void float8_copy_kernel_xpu(TensorIteratorBase& iter) { gpu_kernel(iter, CopyScalarFunc()); break; } + } else if (dtype == kFloat8_e8m0fnu) { + switch (other_dtype) { + case kFloat: + gpu_kernel_nocast(iter, CastScalarFunc()); + break; + case kHalf: + gpu_kernel_nocast(iter, CastScalarFunc()); + break; + case kBFloat16: + gpu_kernel_nocast(iter, CastScalarFunc()); + break; + default: + gpu_kernel(iter, CopyScalarFunc()); + break; + } } else { TORCH_CHECK( false, @@ -114,11 +129,8 @@ void copy_kernel(TensorIteratorBase& iter) { kBool, kBFloat16, kComplexHalf, - AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), - kFloat8_e4m3fn, - kFloat8_e5m2, - kFloat8_e4m3fnuz, - kFloat8_e5m2fnuz); + AT_EXPAND(AT_FLOAT8_TYPES), + AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); } } diff --git a/test/regressions/test_copy.py b/test/regressions/test_copy.py index ff90efce0..363f85002 100644 --- a/test/regressions/test_copy.py +++ b/test/regressions/test_copy.py @@ -1,13 +1,19 @@ # Owner(s): ["module: intel"] import torch -from torch.testing._internal.common_utils import TestCase +from torch.testing._internal.common_device_type import ( + dtypes, + instantiate_device_type_tests, +) +from torch.testing._internal.common_dtype import float8_types_and +from torch.testing._internal.common_utils import run_tests, TestCase cpu_device = torch.device("cpu") xpu_device = torch.device("xpu") class TestSimpleCopy(TestCase): - def test_copy_and_clone(self, dtype=torch.float): + @dtypes(*float8_types_and(torch.float8_e8m0fnu, torch.float32)) + def test_copy_and_clone(self, dtype): a_cpu = torch.randn(16, 64, 28, 28) b_cpu = torch.randn(16, 64, 28, 28) a_xpu = a_cpu.to(xpu_device) @@ -20,3 +26,10 @@ def test_copy_and_clone(self, dtype=torch.float): b_cpu = a_cpu.clone(memory_format=torch.channels_last) b_xpu = a_xpu.clone(memory_format=torch.channels_last) self.assertEqual(b_cpu, b_xpu.to(cpu_device)) + + +instantiate_device_type_tests(TestSimpleCopy, globals(), only_for="xpu", allow_xpu=True) + + +if __name__ == "__main__": + run_tests()