|
3 | 3 | # |
4 | 4 | # This source code is licensed under the BSD 3-Clause license found in the |
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | | -from typing import Any, Dict |
| 6 | +from typing import Any, Dict, Tuple |
7 | 7 |
|
8 | 8 | import torch |
9 | 9 |
|
@@ -50,6 +50,52 @@ def float8_desugar_op(aten_op, args, kwargs=None): |
50 | 50 | ) |
51 | 51 |
|
52 | 52 |
|
@implements([aten.split.Tensor])
def float8_split(aten_op, args, kwargs=None):
    """Split a Float8Tensor along a dimension.

    Runs the split on the raw fp8 payload (``_data``) and rewraps each
    resulting chunk as a Float8Tensor. All chunks share the input's scale,
    original dtype, and mm config — a split does not change quantization
    parameters, only which elements each view covers.

    Args:
        aten_op: the intercepted ``aten.split.Tensor`` overload.
        args: ``args[0]`` is the Float8Tensor to split; the remaining
            positional args (split size, dim) are forwarded unchanged.
        kwargs: forwarded unchanged to the underlying op.

    Returns:
        list of Float8Tensor chunks, in the same order ``aten_op`` returns.
    """
    src = args[0]
    new_data_tensors = aten_op(src._data, *args[1:], **kwargs)
    # Rewrap each data chunk with the (shared) quantization metadata.
    return [
        Float8Tensor(data, src._scale, src._orig_dtype, src._mm_config)
        for data in new_data_tensors
    ]
| 64 | + |
| 65 | + |
# NOTE: cat_cuda does not support the float8 e4m3fn dtype directly, so we
# view the fp8 payloads as uint8, concatenate, and view the result back.
@implements([aten.cat.default])
def float8_cat(aten_op, args, kwargs=None):
    """Concatenate a sequence of Float8Tensor chunks.

    All chunks must carry identical quantization metadata (same scale
    object, mm config, original dtype, and fp8 storage dtype) — i.e. they
    are expected to come from splitting a single Float8Tensor, so the cat
    result can be wrapped with the shared metadata.

    Args:
        aten_op: the intercepted ``aten.cat.default`` overload.
        args: ``args[0]`` is the sequence of Float8Tensor chunks; remaining
            positional args (e.g. dim) are forwarded unchanged.
        kwargs: forwarded unchanged to the underlying op.

    Returns:
        a single Float8Tensor holding the concatenated data.

    Raises:
        AssertionError: if any chunk is not a Float8Tensor or its metadata
            differs from the first chunk's.
    """
    chunked_tensors: Tuple[Float8Tensor, ...] = args[0]

    # Reference metadata taken from the first chunk; every other chunk
    # must match it exactly (identity for scale/mm_config, equality for
    # dtypes) for the concatenated result to be well-defined.
    orig_dtype = chunked_tensors[0]._orig_dtype
    scale = chunked_tensors[0]._scale
    mm_config = chunked_tensors[0]._mm_config
    fp8_dtype = chunked_tensors[0]._data.dtype
    chunk_data = []
    for chunk in chunked_tensors:
        assert isinstance(
            chunk, Float8Tensor
        ), "Expecting all chunks to be of type Float8Tensor"
        assert (
            chunk._orig_dtype == orig_dtype
        ), "Expecting all chunks to be of the same dtype"
        assert (
            chunk._scale is scale
        ), "Expecting all chunks to have the same scale as a result of a split"
        assert (
            chunk._mm_config is mm_config
        ), "Expecting all chunks to have the same mm config as a result of a split"
        assert (
            chunk._data.dtype == fp8_dtype
        ), "Expecting all chunks to be of the same dtype as a result of a split"
        # uint8 view: same bytes, but a dtype cat_cuda supports.
        chunk_data.append(chunk._data.view(torch.uint8))

    new_data = aten_op(chunk_data, *args[1:], **kwargs)
    # View the concatenated bytes back as the original fp8 dtype.
    new_data = new_data.view(fp8_dtype)
    return Float8Tensor(new_data, scale, orig_dtype, mm_config)
| 97 | + |
| 98 | + |
53 | 99 | @implements([aten.sum.dim_IntList]) |
54 | 100 | def float8_cast_up_op(aten_op, args, kwargs=None): |
55 | 101 | """Be careful with this function, this is a "fallback" op that |
|
0 commit comments