diff --git a/src/HGQ/layers/base.py b/src/HGQ/layers/base.py
index cbdb144..6867c9c 100644
--- a/src/HGQ/layers/base.py
+++ b/src/HGQ/layers/base.py
@@ -31,6 +31,14 @@ def input_bw(self):
         except AssertionError:
             return None
 
+    @property
+    def input_bw_exact(self):
+        assert len(self._inbound_nodes) <= 1, f"Layer {self.name} is reused {len(self._inbound_nodes)} times. This is not allowed."
+        try:
+            return self.last_layer.act_bw_exact.astype(np.float32)
+        except AssertionError:
+            return None
+
 
 class HLayerBase(ABSBaseLayer):
     """Abstract base class for all layers in the library. Child classes should call post_build() after calling their build() method.
@@ -186,6 +194,12 @@ def act_bw(self):
         bw = scale_grad(bw, tf.sqrt(1. / self.paq.degeneracy))  # type: ignore
         return tf.broadcast_to(bw, (1,) + self.output_shape[1:])
 
+    @property
+    def act_bw_exact(self):
+        """Returns the exact bitwidth of the pre-activation values. Non-differentiable. For post-training use."""
+        kn, int_bits, fb = self.paq.get_bits_exact(pos_only=self._relu_act)
+        return int_bits + fb + kn
+
     @property
     def fused_bias(self):
         return self.bias
diff --git a/src/HGQ/layers/conv.py b/src/HGQ/layers/conv.py
index 8a6701f..9a6d381 100644
--- a/src/HGQ/layers/conv.py
+++ b/src/HGQ/layers/conv.py
@@ -129,7 +129,7 @@ def jit_forward(self, x, training=None, record_minmax=None):
     @property
     def compute_exact_bops(self):
         kernel_bw = tf.constant(self.kernel_bw_exact, dtype=tf.float32)
-        input_bw = self.input_bw  # type: ignore
+        input_bw = self.input_bw_exact
         bops = int(tf.reduce_sum(self.convolution_op(input_bw, kernel_bw)).numpy()) * int(self.parallel_factor.numpy()) / int(self.total_channels.numpy())  # type: ignore
         self.bops.assign(tf.constant(bops, dtype=tf.float32))
         return bops
diff --git a/src/HGQ/layers/dense.py b/src/HGQ/layers/dense.py
index 04da7d9..ac6748a 100644
--- a/src/HGQ/layers/dense.py
+++ b/src/HGQ/layers/dense.py
@@ -51,7 +51,7 @@ def forward(self, x, training=None, record_minmax=None):
         kernel_bw = self._kernel_bw(kq)  # type: ignore
         bops = tf.reduce_sum(tf.matmul(input_bw, kernel_bw))
         self.bops.assign(bops)
-        bops = tf.cast(bops, tf.float32) * self.beta
+        bops = tf.cast(bops, tf.float32) * self.beta  # type: ignore
         self.add_loss(tf.convert_to_tensor(bops))
         return a
 
@@ -71,8 +71,8 @@ def jit_forward(self, x, training=None, record_minmax=None):
     @property
     def compute_exact_bops(self):
         kernel_bw = self.kernel_bw_exact
-        input_bw = self.input_bw.numpy()  # type: ignore
-        bops = np.sum(np.matmul(input_bw, kernel_bw))
+        input_bw = self.input_bw_exact
+        bops = np.sum(np.matmul(input_bw, kernel_bw))  # type: ignore
         self.bops.assign(tf.constant(bops, dtype=tf.float32))
         return bops
 
@@ -156,7 +156,7 @@ def forward(self, x, training=None, record_minmax=None):
         kernel_bw = self._kernel_bw(kq)  # type: ignore
         bops = tf.reduce_sum(tf.matmul(input_bw, kernel_bw))
         self.bops.assign(bops)
-        bops = tf.cast(bops, tf.float32) * self.beta
+        bops = tf.cast(bops, tf.float32) * self.beta  # type: ignore
         self.add_loss(tf.convert_to_tensor(bops))
         return a
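
For context, a minimal post-training usage sketch of the exact-BOPs path this patch enables is given below. It is an illustration, not part of the patch: `model` and `total_exact_bops` are hypothetical names; only `compute_exact_bops` (and, through it, `input_bw_exact`, `act_bw_exact`, and `kernel_bw_exact`) come from the diff above.

# Hedged sketch: sum exact BOPs over an already-trained Keras model built
# from HGQ layers. `model` and `total_exact_bops` are illustrative names,
# not part of this patch.
def total_exact_bops(model):
    total = 0.0
    for layer in model.layers:
        # Only HGQ layers expose the exact-BOPs property; plain Keras layers are skipped.
        if hasattr(layer, 'compute_exact_bops'):
            # Reading the property evaluates the non-differentiable exact BOPs
            # from input_bw_exact and kernel_bw_exact, assigns the result to
            # layer.bops, and returns it.
            total += float(layer.compute_exact_bops)
    return total

# Example: print(f"exact BOPs: {total_exact_bops(model):.0f}")

Note that, as in the conv.py and dense.py hunks above, reading `compute_exact_bops` has the side effect of overwriting `layer.bops` with the exact value.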