Skip to content

Commit aeeb7f7

Browse files
committed
Add Float8_e8m0fnu support to copy
1 parent 9aac5a1 commit aeeb7f7

File tree

2 files changed

+32
-7
lines changed

2 files changed

+32
-7
lines changed

src/ATen/native/xpu/sycl/CopyKernel.cpp

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,21 @@ void float8_copy_kernel_xpu(TensorIteratorBase& iter) {
8888
gpu_kernel(iter, CopyScalarFunc<Float8_e5m2fnuz>());
8989
break;
9090
}
91+
} else if (dtype == kFloat8_e8m0fnu) {
92+
switch (other_dtype) {
93+
case kFloat:
94+
gpu_kernel_nocast(iter, CastScalarFunc<float, Float8_e8m0fnu>());
95+
break;
96+
case kHalf:
97+
gpu_kernel_nocast(iter, CastScalarFunc<Half, Float8_e8m0fnu>());
98+
break;
99+
case kBFloat16:
100+
gpu_kernel_nocast(iter, CastScalarFunc<BFloat16, Float8_e8m0fnu>());
101+
break;
102+
default:
103+
gpu_kernel(iter, CopyScalarFunc<Float8_e8m0fnu>());
104+
break;
105+
}
91106
} else {
92107
TORCH_CHECK(
93108
false,
@@ -114,11 +129,8 @@ void copy_kernel(TensorIteratorBase& iter) {
114129
kBool,
115130
kBFloat16,
116131
kComplexHalf,
117-
AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
118-
kFloat8_e4m3fn,
119-
kFloat8_e5m2,
120-
kFloat8_e4m3fnuz,
121-
kFloat8_e5m2fnuz);
132+
AT_EXPAND(AT_FLOAT8_TYPES),
133+
AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
122134
}
123135
}
124136

test/regressions/test_copy.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
11
# Owner(s): ["module: intel"]
22
import torch
3-
from torch.testing._internal.common_utils import TestCase
3+
from torch.testing._internal.common_device_type import (
4+
dtypes,
5+
instantiate_device_type_tests,
6+
)
7+
from torch.testing._internal.common_dtype import float8_types_and
8+
from torch.testing._internal.common_utils import run_tests, TestCase
49

510
cpu_device = torch.device("cpu")
611
xpu_device = torch.device("xpu")
712

813

914
class TestSimpleCopy(TestCase):
10-
def test_copy_and_clone(self, dtype=torch.float):
15+
@dtypes(*float8_types_and(torch.float8_e8m0fnu, torch.float32))
16+
def test_copy_and_clone(self, dtype):
1117
a_cpu = torch.randn(16, 64, 28, 28)
1218
b_cpu = torch.randn(16, 64, 28, 28)
1319
a_xpu = a_cpu.to(xpu_device)
@@ -20,3 +26,10 @@ def test_copy_and_clone(self, dtype=torch.float):
2026
b_cpu = a_cpu.clone(memory_format=torch.channels_last)
2127
b_xpu = a_xpu.clone(memory_format=torch.channels_last)
2228
self.assertEqual(b_cpu, b_xpu.to(cpu_device))
29+
30+
31+
instantiate_device_type_tests(TestSimpleCopy, globals(), only_for="xpu", allow_xpu=True)
32+
33+
34+
if __name__ == "__main__":
35+
run_tests()

0 commit comments

Comments
 (0)