
Commit 100e35a: more exact bops computation
calad0i committed Jan 25, 2024
1 parent 3721922 commit 100e35a
Showing 3 changed files with 19 additions and 5 deletions.
14 changes: 14 additions & 0 deletions src/HGQ/layers/base.py
@@ -31,6 +31,14 @@ def input_bw(self):
         except AssertionError:
             return None
 
+    @property
+    def input_bw_exact(self):
+        assert len(self._inbound_nodes) <= 1, f"Layer {self.name} is reused {len(self._inbound_nodes)} times. This is not allowed."
+        try:
+            return self.last_layer.act_bw_exact.astype(np.float32)
+        except AssertionError:
+            return None
+
 
 class HLayerBase(ABSBaseLayer):
     """Abstract base class for all layers in the library. Child classes should call post_build() after calling their build() method.
@@ -186,6 +194,12 @@ def act_bw(self):
             bw = scale_grad(bw, tf.sqrt(1. / self.paq.degeneracy))  # type: ignore
         return tf.broadcast_to(bw, (1,) + self.output_shape[1:])
 
+    @property
+    def act_bw_exact(self):
+        """Returns the exact bitwidth of the pre-activation values. Non-differentiable. For post-training use."""
+        kn, int_bits, fb = self.paq.get_bits_exact(pos_only=self._relu_act)
+        return int_bits + fb + kn
+
     @property
     def fused_bias(self):
         return self.bias
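The two new properties make the post-training BOPs (bit operations) accounting exact: act_bw_exact sums the quantizer's keep-negative sign bit (kn), integer bits, and fractional bits per element, and input_bw_exact simply forwards the previous layer's act_bw_exact. A minimal standalone sketch of that decomposition, with hypothetical names rather than HGQ's API:

import numpy as np

def total_bits(keep_negative, int_bits, frac_bits):
    """Per-element exact bitwidth: sign bit + integer bits + fractional bits."""
    return (np.asarray(keep_negative) + np.asarray(int_bits) + np.asarray(frac_bits)).astype(np.float32)

# An unsigned activation with 3 integer and 4 fractional bits occupies 7 bits;
# its signed counterpart needs one extra sign bit.
print(total_bits([0, 1], [3, 3], [4, 4]))  # [7. 8.]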
2 changes: 1 addition & 1 deletion src/HGQ/layers/conv.py
@@ -129,7 +129,7 @@ def jit_forward(self, x, training=None, record_minmax=None):
     @property
     def compute_exact_bops(self):
         kernel_bw = tf.constant(self.kernel_bw_exact, dtype=tf.float32)
-        input_bw = self.input_bw  # type: ignore
+        input_bw = self.input_bw_exact
         bops = int(tf.reduce_sum(self.convolution_op(input_bw, kernel_bw)).numpy()) * int(self.parallel_factor.numpy()) / int(self.total_channels.numpy())  # type: ignore
         self.bops.assign(tf.constant(bops, dtype=tf.float32))
         return bops
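The switch from input_bw to input_bw_exact is what makes compute_exact_bops live up to its name: the layer's own convolution_op is applied to bitwidth tensors instead of values, so convolving per-element input bitwidths with per-element kernel bitwidths and summing counts one term per input-bit times kernel-bit product; per the line above, the sum is then rescaled by parallel_factor / total_channels. A minimal 1-D NumPy sketch of the counting idea (a hypothetical helper, not the layer's implementation):

import numpy as np

def conv1d_bops(input_bw, kernel_bw):
    """Count input-bit * kernel-bit products over a 'valid' 1-D convolution."""
    n, k = len(input_bw), len(kernel_bw)
    return int(sum(
        input_bw[i + j] * kernel_bw[j]
        for i in range(n - k + 1)  # each valid output position
        for j in range(k)          # each kernel tap
    ))

# 6 inputs at 8 bits and a 3-tap kernel at 4 bits:
# 4 output positions * 3 taps * (8 * 4) bit products = 384 BOPs.
print(conv1d_bops(np.full(6, 8), np.full(3, 4)))  # 384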
8 changes: 4 additions & 4 deletions src/HGQ/layers/dense.py
@@ -51,7 +51,7 @@ def forward(self, x, training=None, record_minmax=None):
         kernel_bw = self._kernel_bw(kq)  # type: ignore
         bops = tf.reduce_sum(tf.matmul(input_bw, kernel_bw))
         self.bops.assign(bops)
-        bops = tf.cast(bops, tf.float32) * self.beta
+        bops = tf.cast(bops, tf.float32) * self.beta  # type: ignore
         self.add_loss(tf.convert_to_tensor(bops))
         return a
 
@@ -71,8 +71,8 @@ def jit_forward(self, x, training=None, record_minmax=None):
     @property
     def compute_exact_bops(self):
         kernel_bw = self.kernel_bw_exact
-        input_bw = self.input_bw.numpy()  # type: ignore
-        bops = np.sum(np.matmul(input_bw, kernel_bw))
+        input_bw = self.input_bw_exact
+        bops = np.sum(np.matmul(input_bw, kernel_bw))  # type: ignore
         self.bops.assign(tf.constant(bops, dtype=tf.float32))
         return bops
 
@@ -156,7 +156,7 @@ def forward(self, x, training=None, record_minmax=None):
         kernel_bw = self._kernel_bw(kq)  # type: ignore
         bops = tf.reduce_sum(tf.matmul(input_bw, kernel_bw))
         self.bops.assign(bops)
-        bops = tf.cast(bops, tf.float32) * self.beta
+        bops = tf.cast(bops, tf.float32) * self.beta  # type: ignore
         self.add_loss(tf.convert_to_tensor(bops))
         return a
 
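In the dense layers the same idea needs no convolution: np.sum(np.matmul(input_bw, kernel_bw)) equals the sum over all (i, j) of input_bw[i] * kernel_bw[i, j], one term per bit-level multiply in y = x @ W. A small NumPy check of that identity, with illustrative bitwidths rather than values from a real model:

import numpy as np

input_bw = np.array([8., 6., 4.])       # per-input exact bitwidths (illustrative)
kernel_bw = np.array([[4., 2.],         # per-weight exact bitwidths (illustrative)
                      [4., 2.],
                      [4., 2.]])

bops = np.sum(np.matmul(input_bw, kernel_bw))
# The same count written out term by term:
assert bops == sum(input_bw[i] * kernel_bw[i, j] for i in range(3) for j in range(2))
print(bops)  # 108.0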
