Fix (minifloat): fix jit issues with FloatClamp
Commit 3db9fdc (parent: 966085e)
Showing 6 changed files with 76 additions and 54 deletions.
```diff
@@ -1,25 +1,41 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
+import torch
 from torch import Tensor
 
 
-def mantissa_bits_to_float(bits: str, frexp_compatible: bool = False) -> float:
+def mantissa_bits_to_float(bits: Tensor, frexp_compatible: bool = False) -> float:
     # computes the decimal place value from a given binary tensor
     res = 1.0
     for i, val in enumerate(bits):
         # iterating through from left to right
-        res += ((2 ** -(i + 1)) * float(val))
+        res += ((2 ** -(i + 1)) * val)
     if frexp_compatible:
         return res / 2.
     else:
         return res
 
 
-def get_minifloat_value(
-        exponent_string: str,
-        mantissa_string: str,
-        exponent_bias: Tensor,
-        sign: str = '0') -> float:
-    exponent_value = int(exponent_string, 2)
-    mantissa_value = mantissa_bits_to_float(mantissa_string)
-    return ((-1) ** float(sign)) * 2 ** (exponent_value - exponent_bias) * mantissa_value
+def get_minifloat_value(exponent: Tensor, mantissa: Tensor, exponent_bias: Tensor) -> Tensor:
+    """
+    Returns the minifloat value for a given exponent, mantissa and exponent_bias.
+    It expects the exponent and mantissa in their binary format.
+    """
+    exponent_value = bits_to_dec(exponent)
+    mantissa_value = mantissa_bits_to_float(mantissa)
+    return torch.exp2(exponent_value - exponent_bias) * mantissa_value
+
+
+def dec_to_bits(value: Tensor, bits: int) -> Tensor:
+    # set up mask
+    mask = 2 ** torch.arange(bits - 1, -1, -1).to(value.device, value.dtype)
+    # add dimension, bitwise_and gets the bits needed for the value, the rest is converting to byte
+    return value.unsqueeze(-1).bitwise_and(mask).ne(0).byte()
+
+
+def bits_to_dec(bits: Tensor) -> Tensor:
+    # get num of bits used
+    num_bits = len(bits)
+    # convert by summing decimal values of set bits
+    return torch.sum((2 ** torch.arange(num_bits - 1, -1, -1)) * bits)
```
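
The commit replaces Python string parsing (e.g. `int(exponent_string, 2)`) with pure tensor ops, which is what makes these helpers traceable under `torch.jit`. A minimal round-trip sketch of the two new bit-conversion helpers; the input value and bit width are illustrative, not taken from the commit:

```python
import torch

# dec_to_bits unpacks an integer tensor into its big-endian bit vector;
# bits_to_dec folds it back by summing the place values of the set bits.
value = torch.tensor(11)      # integer dtype, as bitwise_and requires
bits = dec_to_bits(value, 4)  # tensor([1, 0, 1, 1], dtype=torch.uint8)
restored = bits_to_dec(bits)  # tensor(11)
assert restored == value
```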
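
Decoding a minifloat likewise stays inside tensor ops: `bits_to_dec` recovers the exponent, `mantissa_bits_to_float` accumulates the fractional place values, and `torch.exp2` applies the biased exponent. A sketch decoding one FP8-style bit pattern (4-bit exponent, 3-bit mantissa, bias 7 are assumptions for illustration); note the rewritten `get_minifloat_value` drops the old `sign` argument, so it returns the unsigned magnitude:

```python
import torch

# Illustrative E4M3-style pattern: exponent 0111 (= 7), mantissa 100 (= 1.5)
exponent = torch.tensor([0, 1, 1, 1])
mantissa = torch.tensor([1, 0, 0])
exponent_bias = torch.tensor(7.)

# 2 ** (7 - 7) * (1 + 2 ** -1) = 1.5
print(get_minifloat_value(exponent, mantissa, exponent_bias))  # tensor(1.5000)
```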