
erranlli/bfloat16-fp8-conversion

Bfloat16 Conversion to FP8 and Back

Implements two Python functions backed by a custom C++ extension.

import torch
from typing import Optional

# The compiled extension must be loaded first (see torch.ops.load_library in solution.py).

@torch.jit.script
def round_to_fp8_represented_as_int8(
    t: torch.Tensor,
    n_mantissa: int,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    return torch.ops.custom.round_to_fp8_represented_as_int8(t, n_mantissa, out)

@torch.jit.script
def undo_int8_fp8(
    fp8_tensor: torch.Tensor,
    n_mantissa: int,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    return torch.ops.custom.undo_int8_fp8(fp8_tensor, n_mantissa, out)

Overview

Custom C++ extension: the conversions are written as a C++ extension using PyTorch's C++ API (ATen and torch::Tensor types) so that bits can be manipulated at a low level.

TorchScript compatibility: the functions are registered as custom operators with TORCH_LIBRARY, so they can be called from TorchScript.

Implement the two functions in C++:

fp8_extension.cpp

Use low-level bit manipulation to convert between bfloat16 and FP8 representations.
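The extension does this in C++. As a pure-Python illustration of the bit layout it works with (not the extension's code), note that a bfloat16 value is the upper 16 bits of the corresponding float32: 1 sign bit, 8 exponent bits (bias 127) and 7 mantissa bits. The helper names `bf16_to_bits`, `bits_to_bf16` and `split_bf16` are hypothetical, and the sketch assumes its input is already exactly representable in bfloat16.

```python
import struct


def bf16_to_bits(x: float) -> int:
    """Return the 16-bit bfloat16 pattern of x by keeping the top half of its float32 bits.

    Assumes x is exactly representable in bfloat16, so dropping the low 16 bits loses nothing.
    """
    (f32,) = struct.unpack("<I", struct.pack("<f", x))
    return f32 >> 16


def bits_to_bf16(b: int) -> float:
    """Inverse of bf16_to_bits: place the 16 bits in the top half of a float32."""
    (x,) = struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))
    return x


def split_bf16(b: int) -> tuple:
    """Split a bfloat16 bit pattern into (sign, biased exponent, mantissa)."""
    sign = (b >> 15) & 0x1
    exponent = (b >> 7) & 0xFF  # 8 exponent bits, bias 127
    mantissa = b & 0x7F         # 7 explicit mantissa bits
    return sign, exponent, mantissa
```

For example, `split_bf16(bf16_to_bits(1.75 * 2**-15))` gives `(0, 112, 96)`: biased exponent 112 (binary 0111 0000) and mantissa 1100 000.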

Use setuptools and torch.utils.cpp_extension to compile the extension.

Build the Extension:

python setup.py clean
python setup.py build_ext --inplace

Register the custom operators in fp8_extension.cpp so they can be used in TorchScript.

TORCH_LIBRARY(custom, m) {
    m.def("round_to_fp8_represented_as_int8(Tensor t, int n_mantissa, Tensor? out=None) -> Tensor", &round_to_fp8_represented_as_int8);
    m.def("undo_int8_fp8(Tensor fp8_tensor, int n_mantissa, Tensor? out=None) -> Tensor", &undo_int8_fp8);
}

// Define the module to be imported from Python
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("round_to_fp8_represented_as_int8", &round_to_fp8_represented_as_int8,
          "Convert bfloat16 to FP8 (int8) with stochastic rounding",
          pybind11::arg("t"), pybind11::arg("n_mantissa"), pybind11::arg("out") = torch::nullopt);

    m.def("undo_int8_fp8", &undo_int8_fp8,
          "Convert FP8 (int8) to bfloat16",
          pybind11::arg("fp8_tensor"), pybind11::arg("n_mantissa"), pybind11::arg("out") = torch::nullopt);
}

Python Wrapper:

solution.py

Make sure to replace the following line with the name of the library built in your environment:

torch.ops.load_library('./fp8_extension.cpython-310-x86_64-linux-gnu.so')

Test the functions with the Python unit tests:

python test_fp8_extension.py

Implementation Details

Bfloat16 to FP8

  • E4M3 and E5M2: NaN, infinity and overflow must be handled differently for the two formats. E4M3 has no infinities and encodes NaN only as S.1111.111, whereas E5M2 has infinities (S.11111.00) and NaNs (S.11111.01, .10, .11).

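  • Subnormal handling: when a bfloat16 number becomes a subnormal in FP8,

    • the implicit leading 1 must be made explicit in the mantissa
    • the exponent must be adjusted: the FP8 exponent field becomes 0 and the mantissa is shifted right accordingly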
  • Stochastic rounding:

    • Mantissa overflow must be handled for normal numbers: rounding up an all-ones mantissa carries into the exponent.
    • Mantissa overflow in subnormals is tricky. Take the bfloat16 value with exponent 0111 0000 (biased 112, i.e. 2^-15) and mantissa 1100 000, i.e. 1.75 × 2^-15. In FP8 E5M2, rounding up gives 00001 00, the minimal normal; rounding down gives 00000 11, the maximal subnormal.
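A minimal scalar sketch of these steps in pure Python, for illustration only: `fp8_round_stochastic` is a hypothetical name, not the extension's C++ code, and the input float is assumed to be exactly representable in bfloat16. The text above does not say what E4M3 does on overflow or for ±inf input; this sketch maps both to NaN (0x7F), which is an assumption (saturating to ±448 is the other common choice). The key trick is to treat the FP8 exponent and mantissa as one packed integer, so a stochastic round-up is a plain `+1` whose carry handles both the subnormal-to-normal case and normal mantissa overflow.

```python
import random
import struct


def fp8_round_stochastic(x: float, n_mantissa: int, rng: random.Random) -> int:
    """Stochastically round a bfloat16-representable float to an FP8 byte.

    n_mantissa=2 selects E5M2, n_mantissa=3 selects E4M3. Returns an int in [0, 255].
    """
    nm = n_mantissa
    bias8 = (1 << (6 - nm)) - 1          # 15 for E5M2, 7 for E4M3
    e5m2 = nm == 2
    limit = 0x7C if e5m2 else 0x7F       # smallest non-finite magnitude code

    # bfloat16 bits = upper half of the float32 bits (low half assumed zero).
    (f32,) = struct.unpack("<I", struct.pack("<f", x))
    b = f32 >> 16
    sign, e, m = (b >> 15) & 1, (b >> 7) & 0xFF, b & 0x7F

    if e == 0xFF:                        # bfloat16 NaN or infinity
        if m != 0 or not e5m2:
            return (sign << 7) | 0x7F    # NaN (E4M3 infinity -> NaN is an assumption)
        return (sign << 7) | 0x7C        # E5M2 infinity

    if e == 0:                           # bfloat16 zero/subnormal: no implicit 1
        sig, exp = m, -126
    else:                                # make the implicit leading 1 explicit
        sig, exp = m | 0x80, e - 127
    # Now |x| == sig * 2**(exp - 7).

    ef = max(exp, 1 - bias8)             # FP8 subnormals share the min normal exponent
    shift = (7 - nm) + (ef - exp)        # bfloat16 bits that fall below the FP8 ulp
    q = sig >> shift                     # truncated magnitude, in FP8 ulps
    r = sig & ((1 << shift) - 1)         # discarded remainder
    code = ((ef + bias8 - 1) << nm) + q  # exponent field is 0 for subnormals

    # Round up with probability r / 2**shift. The +1 is an integer add on the
    # packed code, so a carry moves from the mantissa into the exponent field:
    # max subnormal -> min normal, and all-ones mantissa -> next binade.
    if rng.getrandbits(shift) < r:
        code += 1

    if code >= limit:                    # rounded past the largest finite value
        code = limit                     # E5M2: infinity; E4M3: NaN (assumption)
    return (sign << 7) | code
```

With the example above, `fp8_round_stochastic(1.75 * 2**-15, 2, rng)` returns either `0x03` (00000 11) or `0x04` (00001 00), each with probability 1/2, without any special-case code at the subnormal/normal boundary.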

FP8 to Bfloat16

  • Subnormal handling: when an FP8 subnormal becomes a normal number in bfloat16,
    • the mantissa must be shifted left until its first 1 bit moves out into the implicit position
    • the exponent must be decreased by one for each shift
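A pure-Python sketch of the same steps, building the bfloat16 bit pattern directly (the name `fp8_to_float` is hypothetical; the extension does this on tensors in C++). Every finite FP8 value, including the subnormals, is a normal bfloat16 number, so the only extra work beyond re-biasing the exponent is normalizing FP8 subnormals and mapping the special encodings.

```python
import struct


def fp8_to_float(code: int, n_mantissa: int) -> float:
    """Decode an FP8 byte (E5M2 if n_mantissa=2, E4M3 if 3) via bfloat16 bits."""
    nm = n_mantissa
    bias8 = (1 << (6 - nm)) - 1                # 15 for E5M2, 7 for E4M3
    sign = (code >> 7) & 1
    ef = (code >> nm) & ((1 << (7 - nm)) - 1)  # FP8 exponent field
    mf = code & ((1 << nm) - 1)                # FP8 mantissa field

    if nm == 2 and ef == 0x1F:                 # E5M2: infinity or NaN
        bits = 0x7F80 if mf == 0 else 0x7FC0
    elif nm == 3 and ef == 0xF and mf == 0x7:  # E4M3: only S.1111.111 is NaN
        bits = 0x7FC0
    elif ef == 0 and mf == 0:                  # +/- zero
        bits = 0
    else:
        if ef == 0:                            # FP8 subnormal -> bfloat16 normal
            exp = 1 - bias8
            while not mf & (1 << nm):          # shift the first 1 up to the implicit position
                mf <<= 1
                exp -= 1
            mf &= (1 << nm) - 1                # drop it: the leading 1 is now implicit
        else:
            exp = ef - bias8
        bits = ((exp + 127) << 7) | (mf << (7 - nm))
    bits |= sign << 15
    (x,) = struct.unpack("<f", struct.pack("<I", bits << 16))
    return x
```

For instance, the smallest E5M2 subnormal `0x01` (mantissa 01) needs two shifts, so it decodes to 2^-16.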

Unit Test Details

Because of stochastic rounding, converting a bfloat16 value to FP8 can produce either of the two FP8 values that bracket it. Since FP8 has only 256 bit patterns, we precompute all representable values and sort them. The unit tests then check that each conversion yields one of the two bracketing values, and that repeated conversions produce both.
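A sketch of such a lookup table and the bracketing check, assuming NaN and infinity encodings are excluded and +0/-0 are merged (the helper names `fp8_values` and `bracket` are hypothetical, not the repository's test code):

```python
import bisect


def fp8_values(n_mantissa: int) -> list:
    """All finite FP8 values (E5M2 if n_mantissa=2, E4M3 if 3), sorted; +0 and -0 merged."""
    nm = n_mantissa
    bias = (1 << (6 - nm)) - 1
    vals = set()
    for code in range(256):
        ef = (code >> nm) & ((1 << (7 - nm)) - 1)
        mf = code & ((1 << nm) - 1)
        if nm == 2 and ef == 0x1F:                # E5M2 infinities and NaNs
            continue
        if nm == 3 and ef == 0xF and mf == 0x7:   # E4M3 NaNs
            continue
        if ef == 0:                               # subnormal: no implicit 1
            mag = mf * 2.0 ** (1 - bias - nm)
        else:
            mag = ((1 << nm) + mf) * 2.0 ** (ef - bias - nm)
        vals.add(-mag if code & 0x80 else mag)    # -0.0 == 0.0, so zero is kept once
    return sorted(vals)


def bracket(values: list, x: float) -> tuple:
    """Return adjacent table entries (lo, hi) with lo <= x <= hi; (x, x) if x is in the table."""
    if not values[0] <= x <= values[-1]:
        raise ValueError("x is outside the finite FP8 range")
    i = bisect.bisect_left(values, x)
    if values[i] == x:
        return x, x                               # exactly representable: no rounding
    return values[i - 1], values[i]
```

A test would then assert that every converted value is in `bracket(table, x)`. Built this way, the tables hold 247 finite E5M2 values and 253 finite E4M3 values, i.e. 246 and 252 adjacent pairs, which is consistent with the midpoint counts in the sample output below.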

Test with Randomly Generated Numbers

Generate a large number of random values spanning roughly twice the FP8 range.
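A minimal sketch of one way to draw such inputs, assuming uniform sampling in [-max_value, max_value] followed by truncation to bfloat16; the sampling scheme and the helper name `random_bf16_values` are assumptions, while the count 10240 and `max_value=131072.0` are taken from the E5M2 run in the sample output.

```python
import random
import struct


def random_bf16_values(count: int, max_value: float, rng: random.Random) -> list:
    """Draw `count` uniform floats in [-max_value, max_value] and truncate each to bfloat16."""
    out = []
    for _ in range(count):
        x = rng.uniform(-max_value, max_value)
        (f32,) = struct.unpack("<I", struct.pack("<f", x))
        # Zeroing the low 16 bits truncates toward zero, so |result| <= |x|.
        (bf,) = struct.unpack("<f", struct.pack("<I", f32 & 0xFFFF0000))
        out.append(bf)
    return out
```

Uniform sampling concentrates on large magnitudes and almost never lands in the FP8 subnormal range, which is one reason the known-value tests are needed as well.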

Test with Known Values

We test with known values, edge cases, and boundary values. In particular, values just beyond the largest finite FP8 magnitude can round up to infinity, or down to the maximum positive (or minimum negative) finite value.
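These boundary inputs can be enumerated directly. The sketch below (hypothetical helpers, not the repository's test code) computes the largest finite magnitude and the next grid point above it for each format, then lists the bfloat16 values strictly between them; each of those should stochastically round either down to the largest finite value or up past it (infinity for E5M2, whatever overflow value the implementation uses for E4M3). Negating them gives the negative-side boundary values.

```python
import math


def fp8_overflow_boundary(n_mantissa: int) -> tuple:
    """Return (largest finite FP8 magnitude, the next grid point above it)."""
    nm = n_mantissa
    bias = (1 << (6 - nm)) - 1
    if nm == 2:                 # E5M2: exponent 11111 is reserved for inf/NaN
        max_ef, max_mf = 0x1E, 0x3
    else:                       # E4M3: only mantissa 111 at exponent 1111 is NaN
        max_ef, max_mf = 0xF, 0x6
    ulp = 2.0 ** (max_ef - bias - nm)
    max_finite = ((1 << nm) + max_mf) * ulp
    return max_finite, max_finite + ulp


def bf16_values_between(lo: float, hi: float) -> list:
    """bfloat16 values strictly between lo and hi.

    Assumes 0 < lo < hi and that hi does not exceed the next power of two above lo.
    """
    _, e = math.frexp(lo)           # lo = m * 2**e with 0.5 <= m < 1
    step = math.ldexp(1.0, e - 8)   # bfloat16 spacing in [2**(e-1), 2**e): 7 mantissa bits
    out = []
    x = lo + step
    while x < hi:
        out.append(x)
        x += step
    return out
```

For E5M2 this gives 57344 and 65536 with 31 bfloat16 values in between; for E4M3, 448 and 480 with 15 values in between.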

Test Stochastic Rounding Algorithm

Generate values exactly halfway between adjacent representable FP8 values. With enough trials, both neighbors should appear; for a correct implementation, each is chosen with probability 1/2.
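A small reference model of what this test checks, working on values rather than bits (hypothetical helpers, not the repository's test code): stochastic rounding returns the upper neighbor with probability (x - lo) / (hi - lo), so at an exact midpoint both neighbors should appear in roughly equal numbers over many trials.

```python
import random
from collections import Counter


def stochastic_round_ref(x: float, lo: float, hi: float, rng: random.Random) -> float:
    """Round x in [lo, hi] up to hi with probability (x - lo) / (hi - lo), else down to lo."""
    if x == lo or lo == hi:
        return lo
    p_up = (x - lo) / (hi - lo)
    return hi if rng.random() < p_up else lo


def midpoint_trial_counts(values: list, trials: int, rng: random.Random) -> list:
    """For each adjacent pair in sorted `values`, round their midpoint `trials` times."""
    results = []
    for lo, hi in zip(values, values[1:]):
        mid = (lo + hi) / 2
        results.append(Counter(stochastic_round_ref(mid, lo, hi, rng) for _ in range(trials)))
    return results
```

Using E5M2 neighbors such as 1.0, 1.25, 1.5, 1.75, 2.0, each Counter should contain exactly the two neighbors, each with a frequency close to 0.5.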

Sample Output

python test_fp8_extension.py 

Running test_fp8_conversion_known_values:
  Defined 90 special case values for testing.
  Testing with 2 mantissa bits.
    Added 15872 boundary values to special cases.
    Total values for conversion: 15962
    FP8 conversion check passed for 2 mantissa bits.
  Testing with 3 mantissa bits.
    Added 60 boundary values to special cases.
    Total values for conversion: 150
    FP8 conversion check passed for 3 mantissa bits.
.
Running test_fp8_conversion_random:
  Testing with 2 mantissa bits and max_value=131072.0
    Generated 10240 random values.
    FP8 conversion check passed for 2 mantissa bits.
  Testing with 3 mantissa bits and max_value=1024.0
    Generated 10240 random values.
    FP8 conversion check passed for 3 mantissa bits.
.
Running test_stochastic_rounding_correctness:
  Testing stochastic rounding with 2 mantissa bits.
    Generated 246 midpoint values for testing.
    Converted midpoints to bfloat16.
    Performing 1000 stochastic rounding trials.
      Completed 200/1000 trials.
      Completed 400/1000 trials.
      Completed 600/1000 trials.
      Completed 800/1000 trials.
      Completed 1000/1000 trials.
    Completed all 1000 trials.
      Passed stochastic rounding check for value index 0/246.
      Passed stochastic rounding check for value index 10/246.
      Passed stochastic rounding check for value index 20/246.
      Passed stochastic rounding check for value index 30/246.
      Passed stochastic rounding check for value index 40/246.
      Passed stochastic rounding check for value index 50/246.
      Passed stochastic rounding check for value index 60/246.
      Passed stochastic rounding check for value index 70/246.
      Passed stochastic rounding check for value index 80/246.
      Passed stochastic rounding check for value index 90/246.
      Passed stochastic rounding check for value index 100/246.
      Passed stochastic rounding check for value index 110/246.
      Passed stochastic rounding check for value index 120/246.
      Passed stochastic rounding check for value index 130/246.
      Passed stochastic rounding check for value index 140/246.
      Passed stochastic rounding check for value index 150/246.
      Passed stochastic rounding check for value index 160/246.
      Passed stochastic rounding check for value index 170/246.
      Passed stochastic rounding check for value index 180/246.
      Passed stochastic rounding check for value index 190/246.
      Passed stochastic rounding check for value index 200/246.
      Passed stochastic rounding check for value index 210/246.
      Passed stochastic rounding check for value index 220/246.
      Passed stochastic rounding check for value index 230/246.
      Passed stochastic rounding check for value index 240/246.
    Stochastic rounding correctness verified for 2 mantissa bits.
  Testing stochastic rounding with 3 mantissa bits.
    Generated 252 midpoint values for testing.
    Converted midpoints to bfloat16.
    Performing 1000 stochastic rounding trials.
      Completed 200/1000 trials.
      Completed 400/1000 trials.
      Completed 600/1000 trials.
      Completed 800/1000 trials.
      Completed 1000/1000 trials.
    Completed all 1000 trials.
      Passed stochastic rounding check for value index 0/252.
      Passed stochastic rounding check for value index 10/252.
      Passed stochastic rounding check for value index 20/252.
      Passed stochastic rounding check for value index 30/252.
      Passed stochastic rounding check for value index 40/252.
      Passed stochastic rounding check for value index 50/252.
      Passed stochastic rounding check for value index 60/252.
      Passed stochastic rounding check for value index 70/252.
      Passed stochastic rounding check for value index 80/252.
      Passed stochastic rounding check for value index 90/252.
      Passed stochastic rounding check for value index 100/252.
      Passed stochastic rounding check for value index 110/252.
      Passed stochastic rounding check for value index 120/252.
      Passed stochastic rounding check for value index 130/252.
      Passed stochastic rounding check for value index 140/252.
      Passed stochastic rounding check for value index 150/252.
      Passed stochastic rounding check for value index 160/252.
      Passed stochastic rounding check for value index 170/252.
      Passed stochastic rounding check for value index 180/252.
      Passed stochastic rounding check for value index 190/252.
      Passed stochastic rounding check for value index 200/252.
      Passed stochastic rounding check for value index 210/252.
      Passed stochastic rounding check for value index 220/252.
      Passed stochastic rounding check for value index 230/252.
      Passed stochastic rounding check for value index 240/252.
      Passed stochastic rounding check for value index 250/252.
    Stochastic rounding correctness verified for 3 mantissa bits.
.
----------------------------------------------------------------------
Ran 3 tests in 3.343s

OK

Original Instruction: FP8 stochastic rounding

  • Write a Python function that converts a torch tensor of type bfloat16 to fp8 with N mantissa bits (N is an argument to this function).
  • The fp8 tensor should be stored as a uint8 tensor because not all GPUs support fp8 natively. Cast to fp8, not to uint8. Clarifying Note: This task is about shifting bits, not using .to(torch.uint8).
  • Also, write a function to convert the int8-based tensor back to bfloat16
  • Your function should stochastically round (https://nhigham.com/2020/07/07/what-is-stochastic-rounding/) the source tensor. Note that there are edge cases, all of which should be considered.
  • Write unit tests that assert the expected value of the casting function is close to the source tensor to validate your stochastic rounding implementation. The unit tests are part of the task, and we're looking for detailed and thorough assertions that target the key functionality precisely.

Environment

  • Use a py39 environment
  • Ensure that your task can be executed on an x86 CPU
  • Start from src/solution.py

FAQ

Can I use numpy?

No, only pytorch for the solution. You may use numpy for your unit tests.

Do I have to consider edge cases?

Yes, your solution should work for NaN, inf and subnormals

For which fp8 datatypes should my solution work?

E4M3 and E5M2

How long should this take me?

Successful candidates usually take between a few hours and a full day to arrive at a complete solution

How do I submit the task

Invite @magic-screening-tasks-gh-sa and @EricSteinberger to your private github repo and let us know you are ready.
