Quant tensor not empty #819

Merged · 32 commits · Feb 23, 2024
Changes from 1 commit
Cleanup
Giuseppe5 committed Feb 22, 2024
commit 05677a4143f304446f0df80cdd41dc10c82f8bb3
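In short, this commit removes the idea of a partially populated ("empty") QuantTensor: wherever code previously guarded on is_not_none or on individual fields being None, it now only checks the type, since a QuantTensor is assumed to always carry value, scale, zero_point and bit_width. A minimal sketch of the before/after calling pattern (the consume function and its body are illustrative placeholders, not code from this diff):

from brevitas.quant_tensor import QuantTensor

def consume(x):
    # before: isinstance(x, QuantTensor) and x.is_not_none, or per-field None checks
    # after: the type check alone is enough, metadata is assumed to be populated
    if isinstance(x, QuantTensor):
        return x.value / x.scale
    return x  # a plain torch.Tensor falls through unchanged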
src/brevitas/export/manager.py (5 changes: 2 additions & 3 deletions)

@@ -207,10 +207,9 @@ def _cache_fn_dispatcher(cls, fn, input, *args, **kwargs):
 if isinstance(input, QuantTensor):
     inp_cache = None
     out_cache = None
-    if input.is_not_none:
-        inp_cache = _CachedIO(input, metadata_only=True)
+    inp_cache = _CachedIO(input, metadata_only=True)
 output = fn(input, *args, **kwargs)
-if isinstance(output, QuantTensor) and output.is_not_none:
+if isinstance(output, QuantTensor):
     out_cache = _CachedIO(output, metadata_only=True)
 cached_io = (inp_cache, out_cache)
 if fn in cls._cached_io_handler_map:
src/brevitas/nn/hadamard_classifier.py (12 changes: 5 additions & 7 deletions)

@@ -49,15 +49,13 @@ def forward(self, inp):
 out = inp.value / norm
 out = nn.functional.linear(out, self.proj[:self.out_channels, :self.in_channels])
 out = -self.scale * out
-if inp.scale is not None:
+if isinstance(inp, QuantTensor):
     output_scale = inp.scale * self.scale / norm
-if inp.bit_width is not None:
     output_bit_width = self.max_output_bit_width(inp.bit_width)
-if (self.return_quant_tensor and inp.zero_point is not None and
-        (inp.zero_point != 0.0).any()):
-    raise RuntimeError("Computing zero point of output accumulator not supported yet.")
-else:
-    output_zp = inp.zero_point
+    if (self.return_quant_tensor and inp.zero_point != 0.0).any():
+        raise RuntimeError("Computing zero point of output accumulator not supported yet.")
+    else:
+        output_zp = inp.zero_point
 out = QuantTensor(
     value=out,
     scale=output_scale,
src/brevitas/nn/quant_avg_pool.py (4 changes: 2 additions & 2 deletions)

@@ -59,7 +59,7 @@ def forward(self, input: Union[Tensor, QuantTensor]):
     return self.export_handler(x.value)
 x = x.set(value=super(TruncAvgPool2d, self).forward(x.value))
 if self.is_trunc_quant_enabled:
-    assert x.is_not_none  # check input quant tensor is filled with values
+    assert isinstance(x, QuantTensor)  # check input quant tensor is filled with values
     # remove avg scaling
     rescaled_value = x.value * self._avg_scaling
     x = x.set(value=rescaled_value)
@@ -138,7 +138,7 @@ def forward(self, input: Union[Tensor, QuantTensor]):
 self._cached_kernel_size = k_size
 self._cached_kernel_stride = stride
 if self.is_trunc_quant_enabled:
-    assert y.is_not_none  # check input quant tensor is filled with values
+    assert isinstance(y, QuantTensor)  # check input quant tensor is filled with values
     reduce_size = reduce(mul, k_size, 1)
     rescaled_value = y.value * reduce_size  # remove avg scaling
     y = y.set(value=rescaled_value)
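For context on the unchanged rescaling lines around the new asserts: when truncation quantization is enabled, the averaged output is multiplied back by the number of kernel elements, so the truncation quantizer sees per-window sums rather than averages. A rough standalone sketch of that step, assuming a fixed 2x2 kernel (the quantizer wiring is omitted):

import torch
from functools import reduce
from operator import mul

x = torch.arange(16.).reshape(1, 1, 4, 4)
avg = torch.nn.functional.avg_pool2d(x, kernel_size=2)

k_size = (2, 2)
reduce_size = reduce(mul, k_size, 1)  # 4 elements per pooling window
summed = avg * reduce_size            # undo the 1/4 averaging, i.e. "remove avg scaling"
window_sums = torch.nn.functional.avg_pool2d(x, kernel_size=2, divisor_override=1)
assert torch.allclose(summed, window_sums)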
src/brevitas/nn/quant_upsample.py (4 changes: 2 additions & 2 deletions)

@@ -45,7 +45,7 @@ def forward(self, input: Union[Tensor, QuantTensor]):
 y_value = interpolate(x.value, self.size, self.scale_factor, self.mode, self.align_corners)
 if self.mode != 'nearest':
     # round interpolated values to scale
-    assert x.scale is not None, 'Input scale factor required to interpolate correctly'
+    assert isinstance(x, QuantTensor), 'Input scale factor required to interpolate correctly'
     y_value = round_ste(y_value / x.scale) * x.scale
 y = x.set(value=y_value)
 return self.pack_output(y)
@@ -73,7 +73,7 @@ def forward(self, input: Union[Tensor, QuantTensor]):
     return out
 y_value = interpolate(x.value, self.size, self.scale_factor, self.mode, self.align_corners)
 # round interpolated values to scale
-assert x.scale is not None, 'Input scale factor required to interpolate correctly'
+assert isinstance(x, QuantTensor), 'Input scale factor required to interpolate correctly'
 y_value = round_ste(y_value / x.scale) * x.scale
 y = x.set(value=y_value)
 return self.pack_output(y)
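A note on the unchanged line below the new asserts: interpolation modes other than 'nearest' produce values that fall off the input quantization grid, so they are snapped back to integer multiples of the input scale. A minimal sketch of that snap-to-grid step, using a plain torch.round in place of the straight-through round_ste (illustrative values):

import torch

scale = torch.tensor(0.25)
interpolated = torch.tensor([0.10, 0.37, 0.61])  # off-grid values after e.g. bilinear interpolation
snapped = torch.round(interpolated / scale) * scale
print(snapped)  # tensor([0.0000, 0.2500, 0.5000]), multiples of the scale again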
src/brevitas/quant_tensor/__init__.py (85 changes: 32 additions & 53 deletions)

@@ -42,17 +42,6 @@ def _unpack_quant_tensor(input_data):
     return input_data


-def _is_all_nested_not_none(input_data):
-    if isinstance(input_data, QuantTensor):
-        return input_data.is_not_none
-    elif isinstance(input_data, (tuple, list)):
-        return all([_is_all_nested_not_none(v) for v in input_data])
-    elif isinstance(input_data, dict):
-        return all([_is_all_nested_not_none(v) for v in input_data.values()])
-    else:
-        return True
-
-
 class QuantTensor(QuantTensorBase):

     def __new__(cls, value, scale, zero_point, bit_width, signed, training):
@@ -88,8 +77,7 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
 if kwargs is None:
     kwargs = {}
 if (func not in QUANT_TENSOR_FN_HANDLER or
-        not all(issubclass(t, QuantTensor) for t in types) or
-        not (_is_all_nested_not_none(args) and _is_all_nested_not_none(kwargs))):
+        not all(issubclass(t, QuantTensor) for t in types)):
     args = _unpack_quant_tensor(args)
     kwargs = _unpack_quant_tensor(kwargs)
     return func(*args, **kwargs)
@@ -99,12 +87,6 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
 def tensor(self):
     return self.value

-@property
-def is_not_none(self):
-    return (
-        self.value is not None and self.scale is not None and self.zero_point is not None and
-        self.bit_width is not None and self.signed is not None)
-
 @property
 def _pre_round_int_value(self):
     value = self.value
@@ -120,30 +102,27 @@ def _pre_round_int_value(self):

 @property
 def is_valid(self):
-    if self.is_not_none:
-        with torch.no_grad():
-            pre_round_int_value = self._pre_round_int_value
-            rounded_int_value = torch.round(pre_round_int_value)
-            max_abs_diff = torch.max(torch.abs(pre_round_int_value - rounded_int_value))
-            atol = BFLOAT16_IS_VALID_ATOL if self.value.dtype == torch.bfloat16 else IS_VALID_ATOL
-            is_int = max_abs_diff < atol
-            if self.bit_width >= 2:
-                if self.signed:
-                    is_upper_b = (2.0 ** (self.bit_width - 1) - 1 >= rounded_int_value).all()
-                    is_lower_b = (-2.0 ** (self.bit_width - 1) <= rounded_int_value).all()
-                else:
-                    is_upper_b = (2.0 ** self.bit_width - 1 >= rounded_int_value).all()
-                    is_lower_b = (0. <= rounded_int_value).all()
-                return (is_int & is_upper_b & is_lower_b).item()
-            else:  # binary case
-                unique_vals = rounded_int_value.unique(
-                    sorted=False, return_counts=False, return_inverse=False)
-                is_binary = unique_vals.view(-1).size()[0] == 2
-                is_signed = (unique_vals < 0.).any().item()
-                sign_match = is_signed == self.signed
-                return is_int.item() and is_binary and sign_match
-    else:
-        return False
+    with torch.no_grad():
+        pre_round_int_value = self._pre_round_int_value
+        rounded_int_value = torch.round(pre_round_int_value)
+        max_abs_diff = torch.max(torch.abs(pre_round_int_value - rounded_int_value))
+        atol = BFLOAT16_IS_VALID_ATOL if self.value.dtype == torch.bfloat16 else IS_VALID_ATOL
+        is_int = max_abs_diff < atol
+        if self.bit_width >= 2:
+            if self.signed:
+                is_upper_b = (2.0 ** (self.bit_width - 1) - 1 >= rounded_int_value).all()
+                is_lower_b = (-2.0 ** (self.bit_width - 1) <= rounded_int_value).all()
+            else:
+                is_upper_b = (2.0 ** self.bit_width - 1 >= rounded_int_value).all()
+                is_lower_b = (0. <= rounded_int_value).all()
+            return (is_int & is_upper_b & is_lower_b).item()
+        else:  # binary case
+            unique_vals = rounded_int_value.unique(
+                sorted=False, return_counts=False, return_inverse=False)
+            is_binary = unique_vals.view(-1).size()[0] == 2
+            is_signed = (unique_vals < 0.).any().item()
+            sign_match = is_signed == self.signed
+            return is_int.item() and is_binary and sign_match

 @property
 def device(self):
@@ -168,18 +147,18 @@ def detach_(self):
 def detach(self):
     return QuantTensor(
         self.value.detach(),
-        self.scale.detach() if self.scale is not None else None,
-        self.zero_point.detach() if self.zero_point is not None else None,
-        self.bit_width.detach() if self.bit_width is not None else None,
+        self.scale.detach(),
+        self.zero_point.detach(),
+        self.bit_width.detach(),
         self.signed,
         self.training)

 def contiguous(self):
     return QuantTensor(
         self.value.contiguous(),
-        self.scale.contiguous() if self.scale is not None else None,
-        self.zero_point.contiguous() if self.zero_point is not None else None,
-        self.bit_width.contiguous() if self.bit_width is not None else None,
+        self.scale.contiguous(),
+        self.zero_point.contiguous(),
+        self.bit_width.contiguous(),
         self.signed,
         self.training)

@@ -284,7 +263,7 @@ def cat(tensors, dim, out=None):
     return tensors[0]
 else:
     first_qt = tensors[0]
-    if all([isinstance(qt, QuantTensor) and qt.is_not_none for qt in tensors]):
+    if all([isinstance(qt, QuantTensor) for qt in tensors]):
         for qt in tensors[1:]:
             first_qt.check_scaling_factors_same(qt)
             first_qt.check_zero_points_same(qt)
@@ -364,7 +343,7 @@ def cpu(self, *args, **kwargs):
         self.training)

 def __add__(self, other):
-    if isinstance(other, QuantTensor) and self.is_not_none and other.is_not_none:
+    if isinstance(other, QuantTensor):
         self.check_scaling_factors_same(other)
         output_value = self.value + other.value
         output_scale = (self.scale + other.scale) / 2
@@ -396,7 +375,7 @@ def __rmul__(self, other):
     return self.__mul__(other)

 def __mul__(self, other):
-    if isinstance(other, QuantTensor) and self.is_not_none and other.is_not_none:
+    if isinstance(other, QuantTensor):
         output_value = self.value * other.value
         output_scale = self.scale * other.scale
         output_bit_width = self.bit_width + other.bit_width
@@ -426,7 +405,7 @@ def __str__(self):
return f"QuantTensor(value={self.value}, scale={self.scale}, zero_point={self.zero_point}, bit_width={self.bit_width}, signed_t={self.signed_t}, training_t={self.training_t})"

def __truediv__(self, other):
if isinstance(other, QuantTensor) and self.is_not_none and other.is_not_none:
if isinstance(other, QuantTensor):
output_tensor = self.value / other.value # Note, output tensor not guaranteed to pass self.is_valid()
max_int_denominator = 2 ** (other.bit_width - int(other.signed))
output_scale = self.scale / (other.scale * max_int_denominator)
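For readability of the operator hunks above: only the is_not_none guards change; the metadata propagation rules stay the same. Addition keeps the average of the two scales, multiplication multiplies the scales and adds the bit widths, and division divides by the other operand's scale times its largest representable integer magnitude. A rough sketch of those rules written out with plain tensors (not the QuantTensor class itself):

import torch

scale_a, bit_width_a = torch.tensor(0.1), torch.tensor(8.)
scale_b, bit_width_b = torch.tensor(0.2), torch.tensor(8.)
signed_b = True

add_scale = (scale_a + scale_b) / 2                        # __add__: average of the operand scales
mul_scale = scale_a * scale_b                              # __mul__: product of the scales
mul_bit_width = bit_width_a + bit_width_b                  # __mul__: bit widths add up
max_int_denominator = 2 ** (bit_width_b - int(signed_b))   # __truediv__: 2 ** (bit_width - signed)
div_scale = scale_a / (scale_b * max_int_denominator)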
src/brevitas/quant_tensor/torch_handler.py (5 changes: 1 addition & 4 deletions)

@@ -23,10 +23,7 @@ def decorator(func):

 def quant_invariant_handler(fn, inp, *args, **kwargs):
     out_value = fn(inp.value, *args, **kwargs)
-    if inp.is_not_none:
-        return inp.set(value=out_value)
-    else:
-        return out_value
+    return inp.set(value=out_value)


 @implements(torch.flatten)
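The handler above now assumes every incoming QuantTensor is fully populated: for quant-invariant ops such as torch.flatten, the wrapped function runs on inp.value and the result is rewrapped with the same metadata via inp.set(value=...). A small usage sketch under that assumption (tensor values are made up; the constructor arguments follow the __new__ signature shown earlier in this diff):

import torch
from brevitas.quant_tensor import QuantTensor

qt = QuantTensor(
    value=torch.tensor([[0.25, -0.50], [0.75, 0.00]]),
    scale=torch.tensor(0.25),
    zero_point=torch.tensor(0.),
    bit_width=torch.tensor(8.),
    signed=True,
    training=False)

# quant-invariant handling: run the op on the raw value, keep the metadata attached
flat = qt.set(value=torch.flatten(qt.value))
assert isinstance(flat, QuantTensor)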
tests/brevitas/nn/test_nn_quantizers.py (10 changes: 2 additions & 8 deletions)

@@ -62,14 +62,8 @@ def test_quant_wbiol(model_input, current_cases):

 if kwargs['return_quant_tensor']:
     assert isinstance(output, QuantTensor)
-    # Empty QuantTensor
-    if ( not kwargs['input_quantized'] or kwargs['weight_quant'] is None) and \
-        kwargs['io_quant'] is None:
-        assert output.scale is None
-        assert output.bit_width is None
-    else:  # "Full" QuantTensor
-        assert output.scale is not None
-        assert output.bit_width is not None
+    assert output.scale is not None
+    assert output.bit_width is not None
 else:
     assert isinstance(output, torch.Tensor)
