Fix (minifloat): compute max_value during dependency injection
1 parent 1b2a64b · commit 49489b2
Showing 6 changed files with 79 additions and 126 deletions.
```diff
@@ -1,41 +1,66 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause

 import torch
 from torch import Tensor


-def mantissa_bits_to_float(bits: Tensor, frexp_compatible: bool = False) -> float:
-    # computes the decimal place value from a given binary tensor
+def mantissa_bits_to_float(bits: str, frexp_compatible: bool = False) -> float:
+    # computes the decimal place value from a given binary string
     res = 1.0
     for i, val in enumerate(bits):
         # iterating through from left to right
-        res += ((2 ** -(i + 1)) * val)
+        res += ((2 ** -(i + 1)) * float(val))
     if frexp_compatible:
         return res / 2.
     else:
         return res


-def get_minifloat_value(exponent: Tensor, mantissa: Tensor, exponent_bias: Tensor) -> Tensor:
+def get_minifloat_value(exponent: str, mantissa: str, exponent_bias: int) -> float:
     """
     Returns the minifloat value for a given exponent, mantissa and exponent_bias.
     It expects the exponent and mantissa in their binary format.
     """
-    exponent_value = bits_to_dec(exponent)
+    exponent_value = int(exponent, 2)
     mantissa_value = mantissa_bits_to_float(mantissa)
-    return torch.exp2(exponent_value - exponent_bias) * mantissa_value
+    return 2 ** (exponent_value - exponent_bias) * mantissa_value


-def dec_to_bits(value: Tensor, bits: int) -> Tensor:
-    # set up mask
-    mask = 2 ** torch.arange(bits - 1, -1, -1).to(value.device, value.dtype)
-    # add dimension, bitwise_and gets the bits needed for the value, the rest is converting to byte
-    return value.unsqueeze(-1).bitwise_and(mask).ne(0).byte()
-
-
-def bits_to_dec(bits: Tensor) -> Tensor:
-    # get num of bits used
-    num_bits = len(bits)
-    # convert by summing decimal values of set bits
-    return torch.sum((2 ** torch.arange(num_bits - 1, -1, -1)) * bits)
+def get_max_value(exponent_bit_width, mantissa_bit_width, exponent_bias, nan_values, inf_values):
+    # Idea: take the smallest NaN/inf code, set max_value to the next smaller value
+    # inf without NaN is not possible
+    if inf_values is None and nan_values is None:
+        # no special cases, max_value uses all bits for exponent and mantissa
+        exponent = '1' * exponent_bit_width
+        mantissa = '1' * mantissa_bit_width
+    elif nan_values is not None:
+        # we at least have values for NaN, so initiate MaxValInfNaN
+        special_values = nan_values + inf_values if inf_values is not None else nan_values
+
+        # check that the NaN/inf codes fit into the mantissa
+        if any(map(lambda x: len(x) > mantissa_bit_width, special_values)):
+            raise RuntimeError('NaN/inf codes need to be the same length as the mantissa.')
+
+        # get the minimum special case, our max value is the next smaller value
+        min_special_case = min(map(lambda x: int(x, 2), special_values))
+
+        max_value_mantissa = min_special_case - 1
+
+        if max_value_mantissa < 0:
+            # all mantissa codes are taken, so we need to decrease the exponent
+            exponent = '1' * (exponent_bit_width - 1)
+            # add trailing 0 to reach bit width
+            exponent += '0'
+            # since we decreased the exponent, we can use the full mantissa
+            mantissa = '1' * mantissa_bit_width
+        else:
+            # there is a free mantissa code, so use the full exponent
+            exponent = '1' * exponent_bit_width
+            # get the binary code for max_value_mantissa in the number of mantissa bits
+            mantissa = format(max_value_mantissa, f'0{mantissa_bit_width}b')
+    else:
+        # no NaN values but inf values
+        raise RuntimeError('Minifloat Error: inf value cannot exist without NaN value.')
+
+    # we don't need the sign since we're looking for the max value
+    max_value = get_minifloat_value(
+        exponent=exponent, mantissa=mantissa, exponent_bias=exponent_bias)
+    return max_value
```
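For context, the new `get_max_value` can be sanity-checked against the two common OCP FP8 formats. The snippet below is a minimal sketch, not part of the commit: the import path `float_quant_utils` is a placeholder for wherever the changed file lives in the repository, and it assumes the NaN/inf codes are passed as tuples of mantissa bit-strings, which is what the length check in the diff implies.

```python
# Sanity check for the helpers above (not part of the commit).
# The module path is a placeholder; adjust it to the actual location
# of the changed file.
from float_quant_utils import get_max_value, get_minifloat_value

# Decode one minifloat by hand: exponent '1111' (15), mantissa '110',
# bias 7 gives 2 ** (15 - 7) * (1 + 1/2 + 1/4) = 448.0.
assert get_minifloat_value(exponent='1111', mantissa='110', exponent_bias=7) == 448.0

# E4M3-style format: the all-ones mantissa code '111' of the top exponent
# is reserved for NaN and there is no inf, so max_value backs off by one
# mantissa code (to '110') while keeping the full exponent: 448.
assert get_max_value(
    exponent_bit_width=4,
    mantissa_bit_width=3,
    exponent_bias=7,
    nan_values=('111',),
    inf_values=None) == 448.0

# E5M2-style format: all four mantissa codes of the top exponent are taken
# by inf ('00') and NaN ('01', '10', '11'), so the exponent itself backs
# off to '11110' and the full mantissa '11' becomes usable again: 57344.
assert get_max_value(
    exponent_bit_width=5,
    mantissa_bit_width=2,
    exponent_bias=15,
    nan_values=('01', '10', '11'),
    inf_values=('00',)) == 57344.0
```

Both results match the largest finite values of OCP FP8 E4M3 (448) and E5M2 (57344), exercising both branches of the case split in `get_max_value`: a free mantissa code keeps the full exponent, while fully occupied mantissa codes force the exponent down by one.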