Skip to content

Commit

Permalink
New quantTensor
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Nov 3, 2023
1 parent be76ca5 commit 944e77b
Show file tree
Hide file tree
Showing 8 changed files with 225 additions and 208 deletions.
2 changes: 1 addition & 1 deletion src/brevitas/nn/mixin/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]):
# Avoid inplace operations on the input in case of forward hooks
if not torch._C._get_tracing_state():
if isinstance(inp, QuantTensor):
inp = inp.set(int_value=inp.int_value.rename(None))
inp = inp.set(qt_value=inp.qt_value.rename(None))
else:
inp = inp.rename(None)
return inp
Expand Down
2 changes: 1 addition & 1 deletion src/brevitas/nn/quant_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe
if not self.return_quant_tensor or (output_scale is None and output_zero_point is None):
quant_output = output_tensor
else:
quant_output = QuantTensor(
quant_output = QuantTensor.from_fake_quantized(
output_tensor,
scale=output_scale,
zero_point=output_zero_point,
Expand Down
6 changes: 3 additions & 3 deletions src/brevitas/nn/quant_mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def multi_head_attention(
# Mark dimensions through named tensors.
if not torch._C._get_tracing_state():
if isinstance(query, QuantTensor):
query.value.rename_('L', 'N', 'E')
query.qt_value.rename_('L', 'N', 'E')
else:
query.rename_('L', 'N', 'E')
# self-attention
Expand All @@ -426,7 +426,7 @@ def multi_head_attention(
if not torch._C._get_tracing_state():
for t in [query, key, value]:
if isinstance(t, QuantTensor):
t.value.rename_('L', 'N', 'E')
t.qt_value.rename_('L', 'N', 'E')
else:
t.rename_('L', 'N', 'E')
q, k, v = self.q_proj(query), self.k_proj(key), self.v_proj(value)
Expand Down Expand Up @@ -573,7 +573,7 @@ def multi_head_attention(
# Remove names to avoid errors un unsupported downstream ops
if not torch._C._get_tracing_state():
if isinstance(attn_output, QuantTensor):
attn_output.value.rename_(None)
attn_output.qt_value.rename_(None)
else:
attn_output.rename_(None)

Expand Down
8 changes: 4 additions & 4 deletions src/brevitas/nn/quant_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,10 +430,10 @@ def forward(self, inp, state):
else:
cell = self.cell
quant_outputs = cell(
quant_input.value,
quant_state.value,
quant_weight_ih.value,
quant_weight_hh.value,
_get_dequantize_tensor(quant_input),
_get_dequantize_tensor(quant_state),
_get_dequantize_tensor(quant_weight_ih),
_get_dequantize_tensor(quant_weight_hh),
quant_bias)
quant_output = self.pack_quant_outputs(quant_outputs)
quant_state = self.pack_quant_state(quant_outputs[-1], self.cell.output_quant)
Expand Down
4 changes: 2 additions & 2 deletions src/brevitas/nn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def check_tensors_same_ptr(tensor_list):
if hasattr(t, 'data_ptr'):
ptr = t.data_ptr()
pointers.append(ptr)
elif hasattr(t, 'value') and hasattr(t.value, 'data_ptr'):
pointers.append(t.value.data_ptr())
elif hasattr(t, 'qt_value') and hasattr(t.qt_value, 'data_ptr'):
pointers.append(t.qt_value.data_ptr())
else:
return False
return all(p == pointers[0] for p in pointers)
Expand Down
Loading

0 comments on commit 944e77b

Please sign in to comment.