Skip to content

Commit

Permalink
Tentative fix for QW<8 bit
Browse files Browse the repository at this point in the history
This fixes the weight memory layout and the runtime weight-stride setup for QW < 8 bit.
Tested only on pointwise (1x1) and only on the special scenario of synthetic weights, for now.
  • Loading branch information
FrancescoConti committed Aug 21, 2024
1 parent 140fc2c commit 973cd49
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 10 deletions.
2 changes: 1 addition & 1 deletion neureka/hal/neureka_task.c
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ void neureka_task_set_strides(neureka_task_t *task, const uint32_t k_in,
if (task->kernel_shape == 1) { // 1x1
task->data.cfg.weights_stride.d0 = NEUREKA_WEIGHT_BANDWIDTH_BYTES_1x1;
task->data.cfg.weights_stride.d1 =
NEUREKA_WEIGHT_BANDWIDTH_BYTES_1x1 * num_k_in;
(NEUREKA_WEIGHT_BANDWIDTH_BYTES_1x1 / 8) * task->qw * num_k_in;
} else if (!task->depthwise) { // 3x3
task->data.cfg.weights_stride.d0 = NEUREKA_WEIGHT_BANDWIDTH_BYTES_3x3;
task->data.cfg.weights_stride.d1 =
Expand Down
11 changes: 2 additions & 9 deletions test/NeurekaMemoryLayout.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,8 @@ def weightEncode(
elif height == 1 and width == 1:
# (cout * cinMajor, Bits * cinSubtile)
weight = weight.reshape(-1, bits * cinSubtile)
# Pad only the last dimension to weight bandwidth size
# (-1, Weight Bandwidth)
weight = np.pad(
weight,
((0, 0), (0, NeurekaMemoryLayout._WEIGHT_BANDWIDTH_1x1 - weight.shape[-1])),
"constant",
constant_values=0,
)
weightBandwidthBytes = int(np.ceil(NeurekaMemoryLayout._WEIGHT_BANDWIDTH_1x1 / 8))
# No padding needed here
weightBandwidthBytes = int(np.ceil(bits * cinSubtile / 8))

# Prepare for packing
# (-1, Weight Bandwidth Bytes, 8)
Expand Down

0 comments on commit 973cd49

Please sign in to comment.