
Commit 9aac5a1

Support FP8 in op flip, index_put, and index.Tensor (#2190)
To solve #2207, this change extends support for float8 data types across the XPU tensor indexing and transformation kernels, ensuring these operations are compatible with the new types. It also adds a regression test for flipping float8 tensors and removes the skip for the float8 indexing tests.

**Float8 type support:**

* Updated dispatch macros in `XPUScalar.cpp` and `Indexing.cpp` to include `AT_FLOAT8_TYPES`, enabling float8 support in scalar extraction, indexing, index_put, and deterministic index_put kernels (a sketch of this macro migration follows below).
* Modified `flip_kernel` in `TensorTransformationsKernels.cpp` to support float8 and barebones unsigned types, updating the dispatch mechanism accordingly.
* Included the new dispatch header `Dispatch_v2.h` for the updated dispatch macros.

**Testing improvements:**

* Added a regression test for flipping float8 tensors in `test_index_and_index_put.py` to verify correctness of the operation on XPU.
* Removed the skip for float8 tests in `test_indexing_xpu.py`, re-enabling these tests now that support is implemented.

---------

Co-authored-by: Cui, Yifeng <[email protected]>
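For readers unfamiliar with the macro migration, here is a minimal sketch of the before/after dispatch pattern this commit applies. The function name `my_kernel` and the lambda body are illustrative placeholders, not code from the commit; the macro arguments mirror the hunks below. The legacy `AT_DISPATCH_ALL_TYPES_AND_COMPLEX_ANDn` macros are fixed-arity, so their extra-dtype slots fill up quickly, whereas `AT_DISPATCH_V2` takes an open-ended trailing dtype list that can splice in whole type-list macros via `AT_EXPAND`:

```cpp
// Sketch only: my_kernel is a placeholder; the macro usage mirrors this commit.
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/native/TensorIterator.h>

void my_kernel(at::TensorIteratorBase& iter) {
  // Before (fixed arity -- AND4 allows exactly four extra dtypes):
  //
  //   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
  //       at::ScalarType::ComplexHalf, at::ScalarType::BFloat16,
  //       at::ScalarType::Half, at::ScalarType::Bool,
  //       iter.dtype(), "my_kernel", [&] { /* uses scalar_t */ });
  //
  // After: the trailing arguments form an open-ended dtype list.
  AT_DISPATCH_V2(
      iter.dtype(),
      "my_kernel",
      AT_WRAP([&] {
        // scalar_t is bound to the dispatched dtype inside this lambda.
        (void)sizeof(scalar_t);
      }),
      AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
      AT_EXPAND(AT_FLOAT8_TYPES),
      kComplexHalf,
      kHalf,
      kBool,
      kBFloat16);
}
```

`AT_WRAP` protects the lambda from the preprocessor: without it, commas inside the lambda body would be parsed as macro argument separators.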
1 parent a3efbb3 commit 9aac5a1

5 files changed, 74 insertions(+), 40 deletions(-)

src/ATen/native/xpu/XPUScalar.cpp

Lines changed: 1 addition & 0 deletions
```diff
@@ -32,6 +32,7 @@ Scalar _local_scalar_dense_xpu(const Tensor& self) {
         r = Scalar(*value.const_data_ptr<scalar_t>());
       }),
       AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
+      AT_EXPAND(AT_FLOAT8_TYPES),
       kComplexHalf,
       kHalf,
       kBool,
```
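A note on scope: `_local_scalar_dense` is the op behind `Tensor.item()`, so this one-line addition is what lets `.item()` succeed on float8 XPU tensors instead of failing dtype dispatch.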

src/ATen/native/xpu/sycl/Indexing.cpp

Lines changed: 46 additions & 28 deletions
```diff
@@ -43,14 +43,10 @@ void index_kernel(
     TensorIteratorBase& iter,
     IntArrayRef index_size,
     IntArrayRef index_stride) {
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
-      at::ScalarType::ComplexHalf,
-      at::ScalarType::BFloat16,
-      at::ScalarType::Half,
-      at::ScalarType::Bool,
+  AT_DISPATCH_V2(
       iter.dtype(),
       "index_xpu",
-      [&] {
+      AT_WRAP([&] {
         using dtype = OpaqueType<sizeof(scalar_t)>;
         IndexFunctor<dtype> f;
         _index_kernel(
@@ -61,7 +57,13 @@ void index_kernel(
             IntArrayRef{},
             f,
             true);
-      });
+      }),
+      AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
+      AT_EXPAND(AT_FLOAT8_TYPES),
+      kComplexHalf,
+      kHalf,
+      kBool,
+      kBFloat16);
 }
 
 template <typename ValType>
@@ -588,14 +590,10 @@ void index_put_kernel(
             false);
       });
   } else {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
-        at::ScalarType::ComplexHalf,
-        at::ScalarType::BFloat16,
-        at::ScalarType::Half,
-        at::ScalarType::Bool,
+    AT_DISPATCH_V2(
         iter.dtype(),
         "index_put_xpu",
-        [&] {
+        AT_WRAP([&] {
           using dtype = OpaqueType<sizeof(scalar_t)>;
           IndexPutFunctor<dtype> f;
           _index_kernel(
@@ -606,7 +604,13 @@ void index_put_kernel(
               IntArrayRef{},
               f,
               false);
-        });
+        }),
+        AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
+        AT_EXPAND(AT_FLOAT8_TYPES),
+        kComplexHalf,
+        kHalf,
+        kBool,
+        kBFloat16);
   }
 }
 
@@ -693,14 +697,10 @@ void index_put_deterministic_kernel(
       expandedValue.numel());
 
   if (sliceSize > SIMD) {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
-        at::ScalarType::ComplexHalf,
-        at::ScalarType::BFloat16,
-        at::ScalarType::Half,
-        at::ScalarType::Bool,
+    AT_DISPATCH_V2(
         expandedValue.scalar_type(),
         "index_put_deterministic_kernel",
-        [&] {
+        AT_WRAP([&] {
           launch_index_put_deterministic_kernel<scalar_t, scalar_t>(
               sorted_indices.mutable_data_ptr<int64_t>(),
               orig_indices.mutable_data_ptr<int64_t>(),
@@ -711,17 +711,24 @@ void index_put_deterministic_kernel(
               strideBefore,
               nElemBefore,
               accumulate);
-        });
+        }),
+        AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
+        // TODO: Enable AT_FLOAT8_DTYPES after accumulation behavior is
+        // cleared for float8 dtypes.
+        kFloat8_e4m3fn,
+        kFloat8_e5m2,
+        kFloat8_e4m3fnuz,
+        kFloat8_e5m2fnuz,
+        kComplexHalf,
+        kHalf,
+        kBool,
+        kBFloat16);
   } else {
     // Align acc type with CUDA
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
-        at::ScalarType::ComplexHalf,
-        at::ScalarType::BFloat16,
-        at::ScalarType::Half,
-        at::ScalarType::Bool,
+    AT_DISPATCH_V2(
         expandedValue.scalar_type(),
         "index_put_deterministic_kernel",
-        [&] {
+        AT_WRAP([&] {
           using accscalar_t = at::opmath_type<scalar_t>;
           launch_index_put_deterministic_kernel<scalar_t, accscalar_t>(
               sorted_indices.mutable_data_ptr<int64_t>(),
@@ -733,7 +740,18 @@ void index_put_deterministic_kernel(
               strideBefore,
               nElemBefore,
               accumulate);
-        });
+        }),
+        AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
+        // TODO: Enable AT_FLOAT8_DTYPES after accumulation behavior is
+        // cleared for float8 dtypes.
+        kFloat8_e4m3fn,
+        kFloat8_e5m2,
+        kFloat8_e4m3fnuz,
+        kFloat8_e5m2fnuz,
+        kComplexHalf,
+        kHalf,
+        kBool,
+        kBFloat16);
   }
 
   if (permuted)
```
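Note the asymmetry in the deterministic kernel: the four float8 dtypes are spelled out individually rather than pulled in with `AT_EXPAND(AT_FLOAT8_TYPES)`, per the TODO, because accumulation (`accumulate=True`) semantics for float8 are not settled yet. The accumulator type is the crux. A minimal sketch of what `at::opmath_type` resolves to (the mappings shown are the standard Half/BFloat16/float ones; the float8 mappings are deliberately omitted here since that is exactly the open question):

```cpp
// Sketch: at::opmath_type<T> names the wider "op math" type a kernel should
// accumulate in, which is why the accumulate branch above dispatches
// launch_index_put_deterministic_kernel<scalar_t, accscalar_t>.
#include <type_traits>
#include <ATen/OpMathType.h>

static_assert(std::is_same_v<at::opmath_type<at::Half>, float>);
static_assert(std::is_same_v<at::opmath_type<at::BFloat16>, float>);
static_assert(std::is_same_v<at::opmath_type<float>, float>);
static_assert(std::is_same_v<at::opmath_type<double>, double>);
```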

src/ATen/native/xpu/sycl/TensorTransformationsKernels.cpp

Lines changed: 11 additions & 6 deletions
```diff
@@ -1,5 +1,6 @@
 // #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include <ATen/Dispatch.h>
+#include <ATen/Dispatch_v2.h>
 #include <ATen/WrapDimUtilsMulti.h>
 #include <ATen/native/xpu/sycl/MemoryAccess.h>
 #include <ATen/native/xpu/sycl/OffsetCalculator.h>
@@ -129,16 +130,20 @@ void flip_kernel(TensorIterator& iter, bool quantized) {
   if (quantized) {
     TORCH_CHECK(false, "XPU current does not flip for quantized tensor");
   }
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
-      at::ScalarType::Half,
-      at::ScalarType::Bool,
-      at::ScalarType::BFloat16,
+  AT_DISPATCH_V2(
       iter.dtype(),
       "flip_xpu",
-      [&] {
+      AT_WRAP([&] {
         using dtype = OpaqueType<sizeof(scalar_t)>;
         flip_kernel_impl<dtype>(iter);
-      });
+      }),
+      AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
+      AT_EXPAND(AT_FLOAT8_TYPES),
+      AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
+      kComplexHalf,
+      kHalf,
+      kBool,
+      kBFloat16);
 }
 
 template <typename scalar_t>
```
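It is worth spelling out why extending `flip` (and the indexing kernels above) to new dtypes is low risk: these ops only relocate elements, never doing arithmetic on them, and the `OpaqueType<sizeof(scalar_t)>` alias in each lambda makes that explicit. A sketch of the pattern (simplified; an assumption about the exact in-tree definition):

```cpp
// Sketch of the OpaqueType pattern (simplified; the in-tree definition may
// differ in detail). Dispatching on element *size* instead of dtype collapses
// all same-width dtypes onto one kernel instantiation, since the kernel only
// copies bytes around.
#include <cstddef>

template <std::size_t N>
struct alignas(N) OpaqueType {
  char data[N]; // opaque payload; no arithmetic is ever performed on it
};

// Every 1-byte dtype -- int8, uint8, bool, and the float8 family -- lands on
// OpaqueType<1>, so float8 support here adds no new template instantiations.
```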

test/regressions/test_index_and_index_put.py

Lines changed: 15 additions & 0 deletions
```diff
@@ -96,3 +96,18 @@ def test_index_put_with_zero_shape_dim(self, dtype=torch.bfloat16):
         b = torch.randn([5, 0], dtype=dtype, device=torch.device("xpu"))
         a[:5, :] = a[:5, :] * 2 + b
         torch.use_deterministic_algorithms(False)
+
+    def test_flip_float8(self):
+        FLOAT8_DTYPES = (
+            torch.float8_e4m3fn,
+            torch.float8_e4m3fnuz,
+            torch.float8_e5m2,
+            torch.float8_e5m2fnuz,
+            torch.float8_e8m0fnu,
+        )
+        for dtype in FLOAT8_DTYPES:
+            a_cpu = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dtype)
+            a_xpu = a_cpu.to("xpu")
+            b_cpu = torch.flip(a_cpu, [0]).to(torch.float32)
+            b_xpu = torch.flip(a_xpu, [0]).cpu().to(torch.float32)
+            self.assertEqual(b_cpu, b_xpu)
```
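As a usage note (the command is an assumption, not from the commit): `pytest test/regressions/test_index_and_index_put.py -k test_flip_float8` should select just the new test on a machine with a working XPU build. The `.to(torch.float32)` upcast on both sides presumably keeps the comparison in a dtype with full equality support; the flip itself is still performed in float8.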

test/xpu/skip_list_common.py

Lines changed: 1 addition & 6 deletions
```diff
@@ -281,12 +281,7 @@
         # x_cuda = x.clone().detach().to("cuda").requires_grad_(): Torch not compiled with CUDA enabled
         "test_layer_norm_backwards_eps",
     ),
-    "test_indexing_xpu.py": (
-        # XPU implementation doesn't claimn FP8 now
-        # https://github.com/intel/torch-xpu-ops/issues/461
-        # https://github.com/intel/torch-xpu-ops/issues/1975
-        "float8",
-    ),
+    "test_indexing_xpu.py": None,
     "nn/test_pooling_xpu.py": None,
     "nn/test_dropout_xpu.py": None,
     "test_dataloader_xpu.py": None,
```
