
Enable FP8/MXFP8 Ops with requests and CUDA alignment #2207

@CuiYifeng

🚀 The feature, motivation and pitch

Plan to enable the following ops for FP8/MXFP8:
🟢 Supported · 🟡 TBD · ❌ Rejected

| Memory Op | e4m3fn | e4m3fnuz | e5m2 | e5m2fnuz | e8m0fnu | PR Link |
| --- | --- | --- | --- | --- | --- | --- |
| fill/fill_ | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 | |
| flip/fliplr/flipud | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 | #2190 |
| index_put/index_put_ | 🟢 | 🟢 | 🟢 | 🟢 | | #2190 |
| index.Tensor/index.Tensor_out | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 | #2190 |
| index_select/index_select.out | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 | |
| gather/gather.out | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 | |
| cat/cat.out | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 | #2152 |
| eq/eq_ | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 | #2152 |
| ne/ne_ | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 | |
| where | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 | #2152 |
| empty/zeros/ones | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 | |
| to | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 | |
| copy (Float8_e8m0fnu) | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 | #2258 |
| clone | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 | |
| add/sub/mul/div | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | #2145 |
| compare | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | #2154 |
| normal | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | |
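A minimal sketch of how a few of the memory ops above can be exercised on an FP8 dtype is below. The `xpu` device string, the shapes, and the op sequence are illustrative assumptions; whether each op actually runs depends on the PyTorch build and the support status tracked in the table.

```python
import torch

# Illustrative sketch only: exercises some of the memory ops listed above
# on torch.float8_e4m3fn. The "xpu" device string is an assumption; fall
# back to CPU when no XPU backend is available. Op availability for FP8
# dtypes depends on the PyTorch version and backend.
device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cpu"

x = torch.randn(4, 8, device=device).to(torch.float8_e4m3fn)          # to (cast)
y = torch.empty_like(x)                                               # empty
y.fill_(1.0)                                                          # fill_
z = torch.cat([x, y], dim=0)                                          # cat
flipped = torch.flip(z, dims=[0])                                     # flip
rows = torch.index_select(z, 0, torch.tensor([0, 2], device=device))  # index_select
mask = x.eq(y)                                                        # eq
picked = torch.where(mask, x, y)                                      # where
copy = picked.clone()                                                 # clone
print(copy.dtype, copy.shape)
```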
| GEMM OP | Activation & Weight | Scale Modeling | Scale data type | Scale layout | PR Link |
| --- | --- | --- | --- | --- | --- |
| _scaled_mm | FP8 (E4M3/E5M2) | Tensorwise scaling | FP32 | Scalar | pytorch/pytorch#165978 |
| _scaled_mm | FP8 (E4M3/E5M2) | Channelwise scaling | FP32 | Vector | pytorch/pytorch#165978 |
| _scaled_mm | FP8 (E4M3/E5M2) | 128-element 1D / 128x128-element 2D block scaling | FP32 | Tensor | |
| _scaled_mm | MxFP8 (E4M3/E5M2) | 32-element 1D block scaling | UE8M0 | Tiled Tensor | |
| _scaled_grouped_mm | FP8 (E4M3/E5M2) | Tensorwise scaling | FP32 | Scalar | |
| _scaled_grouped_mm | FP8 (E4M3/E5M2) | Channelwise scaling | FP32 | Vector | |
| _scaled_grouped_mm | FP8 (E4M3/E5M2) | 128-element 1D / 128x128-element 2D block scaling | FP32 | Tensor | |
| _scaled_grouped_mm | MxFP8 (E4M3/E5M2) | 32-element 1D block scaling | UE8M0 | Tiled Tensor | |
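For the GEMM side, below is a minimal sketch of `torch._scaled_mm` with tensorwise scaling (FP32 scalar scales, first row of the table). The device selection, shapes, and the simple amax-based scale recipe are illustrative assumptions; the column-major layout of the second operand follows the stock `_scaled_mm` semantics.

```python
import torch

# Illustrative sketch of _scaled_mm with tensorwise scaling (FP32 scalar
# scales). Device choice, shapes, and the amax-based scale computation are
# assumptions for the example, not a prescribed recipe.
device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cuda"

a_hp = torch.randn(128, 64, device=device)   # activations, high precision
b_hp = torch.randn(64, 256, device=device)   # weights, high precision

fp8_max = torch.finfo(torch.float8_e4m3fn).max
scale_a = a_hp.abs().max() / fp8_max          # 0-dim FP32 dequantization scale
scale_b = b_hp.abs().max() / fp8_max

a_fp8 = (a_hp / scale_a).to(torch.float8_e4m3fn)
# _scaled_mm expects the second operand in column-major layout.
b_fp8 = (b_hp / scale_b).t().contiguous().to(torch.float8_e4m3fn).t()

out = torch._scaled_mm(a_fp8, b_fp8, scale_a=scale_a, scale_b=scale_b,
                       out_dtype=torch.bfloat16)
print(out.shape, out.dtype)   # torch.Size([128, 256]) torch.bfloat16
```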

Alternatives

No response

Additional context

No response
