Skip to content

Commit

Permalink
feat (utils): Added decorator to specify the default datatype
Browse files Browse the repository at this point in the history
  • Loading branch information
nickfraser committed Dec 19, 2024
1 parent 9af5ace commit 0edb3df
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
19 changes: 19 additions & 0 deletions src/brevitas/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

import copy
from functools import wraps
from typing import List, Optional, Tuple

import torch
Expand Down Expand Up @@ -127,3 +128,21 @@ def is_broadcastable(tensor, other):
else:
return False
return True


def torch_dtype(dtype):

def decorator(fn):

@wraps(fn)
def wrapped_fn(*args, **kwargs):
cur_dtype = torch.get_default_dtype()
try:
torch.set_default_dtype(dtype)
fn(*args, **kwargs)
finally:
torch.set_default_dtype(cur_dtype)

return wrapped_fn

return decorator
3 changes: 2 additions & 1 deletion tests/brevitas/core/test_float_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from brevitas.core.scaling import FloatScaling
from brevitas.function.ops import max_float
from brevitas.utils.torch_utils import float_internal_scale
from brevitas.utils.torch_utils import torch_dtype
from tests.brevitas.hyp_helper import float_st
from tests.brevitas.hyp_helper import float_tensor_random_shape_st
from tests.brevitas.hyp_helper import random_minifloat_format
Expand Down Expand Up @@ -242,9 +243,9 @@ def test_inner_scale(inp, minifloat_format, scale):
min_bit_width=4, max_bit_with=10, rand_exp_bias=True))
@settings(max_examples=10000)
@jit_disabled_for_mock()
@torch_dtype(torch.float64)
@torch.no_grad()
def test_valid_float_values(minifloat_format_and_value):
torch.set_default_dtype(torch.float64)
minifloat_value, exponent, mantissa, sign, bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format_and_value
scaling_impl = mock.Mock(side_effect=lambda x, y: 1.0)
float_scaling = FloatScaling(None, None, True)
Expand Down

0 comments on commit 0edb3df

Please sign in to comment.